Gaegul's devlog

pretrained model layer 수정하기 본문

Artificial Intelligence/Computer Vision

pretrained model layer 수정하기

부지런깨꾹이 2021. 7. 28. 16:07
728x90
반응형

Case 1. pretrained CNN model layer 직접 수정

pretrained CNN model layer 직접 수정하여 input channel을 변경해 보도록 하겠습니다.

Problem

일반적인 classification task를 위한 cnn 기반 모델의 인풋은 3 channel (RGB)로 들어가게 됩니다.  하지만 데이터를 가공시키고 원하는 input channel 이 3채널이 아닐때(4채널 or 그 이상 채널이 들어가야 할 때) pretrained cnn 모델을 가져와서 사용하고 싶을때 본 방법은 유용합니다. 

저는 se_resnext101 네트워크에 ImageNet dataset이 pretrained 모델을 가져와서 사용하였습니다.

그리고 제가 가공한 이미지 데이터는 9*256*256 였기에 첫번째 conv layer 에서 input channel 을 9로 변경해 주었습니다!

 

두 가지 방법을 공유 해드리려고 합니다!

 

Solution.1 (제가 사용한 코드)

원하는 conv layer 만 수정하는 방법 (layer 수정, weight 수정)

class se_resnext101_32x4d(nn.Module):
    def __init__(self):
        super(se_resnext101_32x4d, self).__init__()
        self.model_ft = pretrainedmodels.__dict__['se_resnext101_32x4d'](num_classes=1000, pretrained='imagenet')
		
        # input channel 변경
        prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
        self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # 새로 정의하는 conv   
        self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1)) # 새로 정의하는 weight
        
        num_ftrs = self.model_ft.last_linear.in_features
        self.model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.model_ft.last_linear = nn.Sequential(nn.Linear(num_ftrs, 6, bias=True))

    def forward(self, x):
        x = self.model_ft(x)
        return x

 

이 세 줄만  추가하여 기존 pretrained 모델 로드하고 기존의 weight 와 바뀔 weight의 shape을 맞춰서 사용하시면 됩니다!

prev_w = self.model_ft.layer0.conv1.weight #기존 pretrained 된 모델 weight
self.model_ft.layer0.conv1 = nn.Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
self.model_ft.layer0.conv1.weight = nn.Parameter(torch.cat((prev_w, torch.zeros(64, 6, 7, 7)), dim=1)) 

 

Solution. 2 

feautre를 list로 받아와서 수정하는 방법.

출처 : https://discuss.pytorch.org/t/how-to-change-no-of-input-channels-to-a-pretrained-model/19379/6

class RGBMaskEncoderCNN(nn.Module):
	def __init__(self):
		super(RGBMaskEncoderCNN, self).__init__()

		self.alexnet = models.alexnet(pretrained=True)
		new_classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:-1])
		self.alexnet.classifier = new_classifier

		# get the pre-trained weights of the first layer
		pretrained_weights = self.alexnet.features[0].weight
		new_features = nn.Sequential(*list(self.alexnet.features.children()))
		new_features[0] = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2)
		# For M-channel weight should randomly initialized with Gaussian
		new_features[0].weight.data.normal_(0, 0.001)
		# For RGB it should be copied from pretrained weights
		new_features[0].weight.data[:, :3, :, :] = nn.Parameter(pretrained_weights)
		self.alexnet.features = new_features

	def forward(self, images):
		"""Extract Feature Vector from Input Images"""
		features = self.alexnet(images)
		return features

 

Case 2. load_state_dict() 에 strict = False로 설정.

 

Problem 

기존 모델을 선언하고, torch.load() 해서 pretrained 해놓은 모델의 weight 값을 가져오는 과정에서 기존 모델에 맞춰줘야 weight 값이 담길 수 있는데 다르면 weight shape 에러가 나게 됩니다.  

예시) 기존 모델 last linear(classifier) = 6인데, 내가 pretrained 해논 모델의 last linear 가 = 1 일때

 

Solution 

state_load_dict() 함수 내부에 strict 옵션이 있는데 이 옵션을 false로 설정하면 같은 shape을 가지는 layer만 flexible하게 가져올 수 있다. 즉, 마지막 classifier 단의 shape이 안맞으면 이 부분은 pretrained 모델의 shape을 따르지 않고, 우리가 설정한 모델의 shape을 따른다.

model = model.cuda() # 기존의 model
pretrained_model = torch.load("...") #새로 pretrained 한 model
model.state_load_dict(pretrained_model, strict=False)
728x90
반응형
Comments