Reputation: 41
I am new to PyTorch and I want to use VGG for transfer learning. I want to delete the fully connected layers and add some new ones. I also want to use grayscale input instead of RGB: for this I will sum the weights of the input layer over the three channels, so that the three channels' weights collapse into a single weight.
I managed to delete the fully connected layers, but I am having trouble with the grayscale part. I add the three channel weights together to form a new weight, then try to change the state dict of the VGG model, but this gives me an error. The network's code is below:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        vgg = models.vgg16(pretrained=True).features[:30]
        w1 = vgg.state_dict()['0.weight'][:, 0, :, :]  # first channel of the first conv layer's weight
        w2 = vgg.state_dict()['0.weight'][:, 1, :, :]
        w3 = vgg.state_dict()['0.weight'][:, 2, :, :]
        w4 = w1 + w2 + w3     # add the weights of the three channels
        w4 = w4.unsqueeze(1)  # make it 4-dimensional: [64, 1, 3, 3]
        a = vgg.state_dict()  # create a new state dict
        a['0.weight'] = w4    # replace the weight in the new state dict
        vgg.load_state_dict(a)  # this line gives the error: load the new state dict
        self.vgg = nn.Sequential(vgg)
        self.fc1 = nn.Linear(14 * 14 * 512, 1000)
        self.fc2 = nn.Linear(1000, 2)

    def forward(self, x):
        x = self.vgg(x)
        x = x.view(-1, 14 * 14 * 512)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
This gives the following error:
RuntimeError: Error(s) in loading state_dict for Sequential: size mismatch for 0.weight: copying a param with shape torch.Size([64, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
So it doesn't allow me to replace the weight with one of a different size. Is there a solution to this problem, or is there anything else I can try? All I want to do is use VGG's layers up to the fully connected layers and change the first layer's weights.
Upvotes: 3
Views: 3241
Reputation: 41
I solved the problem by adding a new conv layer and initializing its weights with the sum of the three channel weights of VGG's first conv layer. Then I excluded the first conv layer of VGG.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        vgg_firstlayer = models.vgg16(pretrained=True).features[0]  # load just the first conv layer
        # everything up to the classification layers except the first conv layer;
        # [1:31] includes the final max-pool, so a 224x224 input yields 7x7x512
        vgg = models.vgg16(pretrained=True).features[1:31]
        w1 = vgg_firstlayer.state_dict()['weight'][:, 0, :, :]
        w2 = vgg_firstlayer.state_dict()['weight'][:, 1, :, :]
        w3 = vgg_firstlayer.state_dict()['weight'][:, 2, :, :]
        w4 = w1 + w2 + w3     # add the weights of the three channels
        w4 = w4.unsqueeze(1)  # make it 4-dimensional: [64, 1, 3, 3]
        first_conv = nn.Conv2d(1, 64, 3, padding=(1, 1))  # create a new conv layer
        first_conv.weight = torch.nn.Parameter(w4, requires_grad=True)  # initialize its weights with w4
        first_conv.bias = torch.nn.Parameter(vgg_firstlayer.state_dict()['bias'], requires_grad=True)  # reuse VGG's first conv bias
        self.first_convlayer = first_conv  # the first layer is a 1-channel (grayscale) conv layer
        self.vgg = vgg
        self.fc1 = nn.Linear(7 * 7 * 512, 1000)
        self.fc2 = nn.Linear(1000, 2)

    def forward(self, x):
        x = self.first_convlayer(x)
        x = self.vgg(x)
        x = x.view(-1, 7 * 7 * 512)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
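A quick sanity check of the shapes, assuming a standard 224x224 grayscale input (so the flattened feature map is 7x7x512):

model = Net()
out = model(torch.randn(1, 1, 224, 224))  # one grayscale image
print(out.shape)  # torch.Size([1, 2])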
Upvotes: 0
Reputation: 24894
One trick used on Kaggle, for example, is to simply add one or two extra layers in front of the VGG input that extend the number of channels to the desired number (3 in this case).
This is more robust than modifying the original model, as it makes it easier to swap out the pretrained backbone.
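A minimal sketch of this idea, assuming a 1x1 convolution as the channel adapter (the adapter layer and the class name are illustrative, not from the original answer):

import torch
import torch.nn as nn
from torchvision import models

class GrayscaleVGG(nn.Module):
    def __init__(self):
        super().__init__()
        self.adapter = nn.Conv2d(1, 3, kernel_size=1)  # map 1 grayscale channel to 3 "RGB" channels
        self.backbone = models.vgg16(pretrained=True).features  # unmodified pretrained backbone

    def forward(self, x):
        x = self.adapter(x)  # [N, 1, H, W] -> [N, 3, H, W]
        return self.backbone(x)

model = GrayscaleVGG()
out = model(torch.randn(1, 1, 224, 224))  # grayscale input works without touching VGG's weights

Since the adapter is the only new layer, swapping vgg16 for another pretrained backbone only requires changing one line.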
Upvotes: 0
Reputation: 3345
You haven't specified where your VGG class comes from, but I assume it's from torchvision.models.
The VGG model is created for images with 3 channels. You can see this in the make_layers method on GitHub.
It's probably not a good idea to modify the code within the torchvision package, but you could create a copy within your project and make the in_channels settable.
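A sketch of what that copy could look like, based on torchvision's make_layers with in_channels exposed as a parameter (treat this as an illustration of the approach, not torchvision's actual API):

import torch.nn as nn

def make_layers(cfg, batch_norm=False, in_channels=3):
    # adapted from torchvision's VGG implementation, with in_channels made settable
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

# VGG16 configuration ('M' marks a max-pool)
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
features = make_layers(cfg, in_channels=1)  # grayscale variant

Note that a model built this way starts from random weights; the pretrained VGG weights won't fit the 1-channel first layer.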
Upvotes: 1