Case 1. pretrained CNN model layer 직접 수정
pretrained CNN model layer 직접 수정하여 input channel을 변경해 보도록 하겠습니다.
Problem
일반적인 classification task를 위한 cnn 기반 모델의 인풋은 3 channel (RGB)로 들어가게 됩니다. 하지만 데이터를 가공시키고 원하는 input channel 이 3채널이 아닐때(4채널 or 그 이상 채널이 들어가야 할 때) pretrained cnn 모델을 가져와서 사용하고 싶을때 본 방법은 유용합니다.
저는 se_resnext101 네트워크에 ImageNet dataset이 pretrained 모델을 가져와서 사용하였습니다.
그리고 제가 가공한 이미지 데이터는 9*256*256 였기에 첫번째 conv layer 에서 input channel 을 9로 변경해 주었습니다!
두 가지 방법을 공유 해드리려고 합니다!
Solution.1 (제가 사용한 코드)
원하는 conv layer 만 수정하는 방법 (layer 수정, weight 수정)
class se_resnext101_32x4d(nn.Module):
def __init__(self):
super(se_resnext101_32x4d, self).__init__()
self.model_ft = pretrainedmodels.__dict__['se_resnext101_32x4d'](num_classes=1000, pretrained='imagenet')
# input channel 변경
prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # 새로 정의하는 conv
self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1)) # 새로 정의하는 weight
num_ftrs = self.model_ft.last_linear.in_features
self.model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1))
self.model_ft.last_linear = nn.Sequential(nn.Linear(num_ftrs, 6, bias=True))
def forward(self, x):
x = self.model_ft(x)
return x
이 세 줄만 추가하여 기존 pretrained 모델 로드하고 기존의 weight 와 바뀔 weight의 shape을 맞춰서 사용하시면 됩니다!
prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1))
Solution. 2
feautre를 list로 받아와서 수정하는 방법.
출처 : https://discuss.pytorch.org/t/how-to-change-no-of-input-channels-to-a-pretrained-model/19379/6
class RGBMaskEncoderCNN(nn.Module):
def __init__(self):
super(RGBMaskEncoderCNN, self).__init__()
self.alexnet = models.alexnet(pretrained=True)
new_classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:-1])
self.alexnet.classifier = new_classifier
# get the pre-trained weights of the first layer
pretrained_weights = self.alexnet.features[0].weight
new_features = nn.Sequential(*list(self.alexnet.features.children()))
new_features[0] = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2)
# For M-channel weight should randomly initialized with Gaussian
new_features[0].weight.data.normal_(0, 0.001)
# For RGB it should be copied from pretrained weights
new_features[0].weight.data[:, :3, :, :] = nn.Parameter(pretrained_weights)
self.alexnet.features = new_features
def forward(self, images):
"""Extract Feature Vector from Input Images"""
features = self.alexnet(images)
return features
Case 2. load_state_dict() 에 strict = False로 설정.
Problem
기존 모델을 선언하고, torch.load() 해서 pretrained 해놓은 모델의 weight 값을 가져오는 과정에서 기존 모델에 맞춰줘야 weight 값이 담길 수 있는데 다르면 weight shape 에러가 나게 됩니다.
예시) 기존 모델 last linear(classifier) = 6인데, 내가 pretrained 해논 모델의 last linear 가 = 1 일때
Solution
state_load_dict() 함수 내부에 strict 옵션이 있는데 이 옵션을 false로 설정하면 같은 shape을 가지는 layer만 flexible하게 가져올 수 있다. 즉, 마지막 classifier 단의 shape이 안맞으면 이 부분은 pretrained 모델의 shape을 따르지 않고, 우리가 설정한 모델의 shape을 따른다.
model = model.cuda() # 기존의 model
pretrained_model = torch.load("...") #새로 pretrained 한 model
model.state_load_dict(pretrained_model, strict=False)
'Artificial Intelligence > Computer Vision' 카테고리의 다른 글
Optical Flow 개념 및 알고리즘 종류 (0) | 2021.09.28 |
---|---|
[의료 이미지] CT image 이해하기 (feat. Dicom 파일) (0) | 2021.06.22 |
Action speaks louder than words. 하루 하루의 기록을 습관화 합니다 📖
포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!