Gaegul's devlog
pretrained model layer 수정하기 본문
Case 1. pretrained CNN model layer 직접 수정
pretrained CNN model layer 직접 수정하여 input channel을 변경해 보도록 하겠습니다.
Problem
일반적인 classification task를 위한 cnn 기반 모델의 인풋은 3 channel (RGB)로 들어가게 됩니다. 하지만 데이터를 가공시키고 원하는 input channel 이 3채널이 아닐때(4채널 or 그 이상 채널이 들어가야 할 때) pretrained cnn 모델을 가져와서 사용하고 싶을때 본 방법은 유용합니다.
저는 se_resnext101 네트워크에 ImageNet dataset이 pretrained 모델을 가져와서 사용하였습니다.
그리고 제가 가공한 이미지 데이터는 9*256*256 였기에 첫번째 conv layer 에서 input channel 을 9로 변경해 주었습니다!
두 가지 방법을 공유 해드리려고 합니다!
Solution.1 (제가 사용한 코드)
원하는 conv layer 만 수정하는 방법 (layer 수정, weight 수정)
class se_resnext101_32x4d(nn.Module):
def __init__(self):
super(se_resnext101_32x4d, self).__init__()
self.model_ft = pretrainedmodels.__dict__['se_resnext101_32x4d'](num_classes=1000, pretrained='imagenet')
# input channel 변경
prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # 새로 정의하는 conv
self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1)) # 새로 정의하는 weight
num_ftrs = self.model_ft.last_linear.in_features
self.model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1))
self.model_ft.last_linear = nn.Sequential(nn.Linear(num_ftrs, 6, bias=True))
def forward(self, x):
x = self.model_ft(x)
return x
이 세 줄만 추가하여 기존 pretrained 모델 로드하고 기존의 weight 와 바뀔 weight의 shape을 맞춰서 사용하시면 됩니다!
prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1))
Solution. 2
feautre를 list로 받아와서 수정하는 방법.
출처 : https://discuss.pytorch.org/t/how-to-change-no-of-input-channels-to-a-pretrained-model/19379/6
class RGBMaskEncoderCNN(nn.Module):
def __init__(self):
super(RGBMaskEncoderCNN, self).__init__()
self.alexnet = models.alexnet(pretrained=True)
new_classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:-1])
self.alexnet.classifier = new_classifier
# get the pre-trained weights of the first layer
pretrained_weights = self.alexnet.features[0].weight
new_features = nn.Sequential(*list(self.alexnet.features.children()))
new_features[0] = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2)
# For M-channel weight should randomly initialized with Gaussian
new_features[0].weight.data.normal_(0, 0.001)
# For RGB it should be copied from pretrained weights
new_features[0].weight.data[:, :3, :, :] = nn.Parameter(pretrained_weights)
self.alexnet.features = new_features
def forward(self, images):
"""Extract Feature Vector from Input Images"""
features = self.alexnet(images)
return features
Case 2. load_state_dict() 에 strict = False로 설정.
Problem
기존 모델을 선언하고, torch.load() 해서 pretrained 해놓은 모델의 weight 값을 가져오는 과정에서 기존 모델에 맞춰줘야 weight 값이 담길 수 있는데 다르면 weight shape 에러가 나게 됩니다.
예시) 기존 모델 last linear(classifier) = 6인데, 내가 pretrained 해논 모델의 last linear 가 = 1 일때
Solution
state_load_dict() 함수 내부에 strict 옵션이 있는데 이 옵션을 false로 설정하면 같은 shape을 가지는 layer만 flexible하게 가져올 수 있다. 즉, 마지막 classifier 단의 shape이 안맞으면 이 부분은 pretrained 모델의 shape을 따르지 않고, 우리가 설정한 모델의 shape을 따른다.
model = model.cuda() # 기존의 model
pretrained_model = torch.load("...") #새로 pretrained 한 model
model.state_load_dict(pretrained_model, strict=False)
'Artificial Intelligence > Computer Vision' 카테고리의 다른 글
Optical Flow 개념 및 알고리즘 종류 (0) | 2021.09.28 |
---|---|
[의료 이미지] CT image 이해하기 (feat. Dicom 파일) (0) | 2021.06.22 |