sibidora

Reputation: 41

PyTorch grayscale input to VGG

I am new to PyTorch and I want to use VGG for transfer learning. I want to remove the fully connected layers and add some new fully connected layers. Also, rather than RGB input, I want to use grayscale input. To do this, I will sum the first layer's weights over its three input channels to get a single-channel weight.

I managed to remove the fully connected layers, but I am having trouble with the grayscale part. I add the three channel weights together to form a new weight, then try to replace it in the VGG model's state dict, but this raises an error. The network's code is below:

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        vgg = models.vgg16(pretrained=True).features[:30]

        w1 = vgg.state_dict()['0.weight'][:, 0, :, :]  # first input channel of the first conv layer's weight
        w2 = vgg.state_dict()['0.weight'][:, 1, :, :]
        w3 = vgg.state_dict()['0.weight'][:, 2, :, :]
        w4 = w1 + w2 + w3  # add the three channels' weights
        w4 = w4.unsqueeze(1)  # make it 4-dimensional: [64, 1, 3, 3]

        a = vgg.state_dict()  # create a new state dict
        a['0.weight'] = w4  # replace the first layer's weight in the new state dict

        vgg.load_state_dict(a)  # this line raises the error when loading the new state dict

        self.vgg = nn.Sequential(vgg)
        self.fc1 = nn.Linear(14 * 14 * 512, 1000)
        self.fc2 = nn.Linear(1000, 2)

    def forward(self, x):
        x = self.vgg(x)
        x = x.view(-1, 14 * 14 * 512)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

This gives an error of:

RuntimeError: Error(s) in loading state_dict for Sequential: size mismatch for 0.weight: copying a param with shape torch.Size([64, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).

So it doesn't allow me to replace the weight with one of a different size. Is there a solution to this problem, or is there anything else I can try? All I want to do is use VGG's layers up to the fully connected layers and change the first layer's weights.

Upvotes: 3

Views: 3241

Answers (4)

sibidora

Reputation: 41

I solved the problem by adding a new conv layer and initializing its weights with the sum of the three input-channel weights of VGG's first conv layer. Then I excluded VGG's first conv layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        vgg_firstlayer = models.vgg16(pretrained=True).features[0]  # load just the first conv layer
        vgg = models.vgg16(pretrained=True).features[1:30]  # load up to the classification layers, except the first conv layer

        w1 = vgg_firstlayer.state_dict()['weight'][:, 0, :, :]
        w2 = vgg_firstlayer.state_dict()['weight'][:, 1, :, :]
        w3 = vgg_firstlayer.state_dict()['weight'][:, 2, :, :]
        w4 = w1 + w2 + w3  # add the three channels' weights
        w4 = w4.unsqueeze(1)  # make it 4-dimensional: [64, 1, 3, 3]

        first_conv = nn.Conv2d(1, 64, 3, padding=(1, 1))  # create a new conv layer
        # attribute must be spelled `weight`; a misspelled name leaves the random init in place
        first_conv.weight = torch.nn.Parameter(w4, requires_grad=True)  # initialize the conv layer's weights with w4
        first_conv.bias = torch.nn.Parameter(vgg_firstlayer.state_dict()['bias'], requires_grad=True)  # initialize the conv layer's bias with VGG's first conv bias

        self.first_convlayer = first_conv  # the first layer is a 1-channel (grayscale) conv layer
        self.vgg = nn.Sequential(vgg)

        # 7*7*512 matches 112x112 inputs; 224x224 inputs would give 14*14*512 here
        self.fc1 = nn.Linear(7 * 7 * 512, 1000)
        self.fc2 = nn.Linear(1000, 2)

    def forward(self, x):
        x = self.first_convlayer(x)
        x = self.vgg(x)

        x = x.view(-1, 7 * 7 * 512)
        x = F.relu(self.fc1(x))

        x = self.fc2(x)
        return x

Upvotes: 0

Szymon Maszke

Reputation: 24894

One trick used on Kaggle, for example, is to simply put one or two additional layers in front of the VGG input that expand the channels to the desired number (3 in this case).

This is more robust than modifying the original model, as it makes it easier to swap the pretrained backbone.
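
A minimal sketch of this idea, assuming a single 1x1 convolution is enough to expand the channels. The `GrayToRGBAdapter` name is hypothetical, and the small `backbone` below is only a stand-in for a pretrained model so the example runs without downloading weights:

```python
import torch
import torch.nn as nn


class GrayToRGBAdapter(nn.Module):
    """Hypothetical wrapper: expands 1-channel input to 3 channels, then runs the backbone."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        # 1x1 convolution that learns how to spread the gray channel over 3 channels.
        self.expand = nn.Conv2d(1, 3, kernel_size=1)
        self.backbone = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(self.expand(x))


# Stand-in for a pretrained 3-channel backbone (assumption, not the real VGG).
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
model = GrayToRGBAdapter(backbone)

gray_batch = torch.randn(2, 1, 32, 32)
out = model(gray_batch)
print(out.shape)  # torch.Size([2, 8, 32, 32])
```

In practice the stand-in would presumably be replaced by something like the `features` part of a pretrained torchvision VGG, which stays untouched because only the adapter layer is new.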

Upvotes: 0

Yixuan Wei

Reputation: 181

  • In short: The error is caused by a mismatch between the modified pretrained parameters and the VGG model's structure.
  • Reason: You changed the first layer's weight in the pretrained state dict from [64,3,3,3] to [64,1,3,3] by summing, but you didn't change the structure of VGG, whose first conv layer still expects a [64,3,3,3] weight.
  • Resolution: Remove the first convolution layer from the VGG structure and add a new one that fits your input shape.
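
The resolution above could be sketched roughly as follows. The helper name `to_grayscale_conv` is hypothetical, and a two-layer `nn.Sequential` stands in for the VGG feature extractor so the example runs without downloading weights:

```python
import torch
import torch.nn as nn


def to_grayscale_conv(conv: nn.Conv2d) -> nn.Conv2d:
    """Build a 1-channel copy of `conv` whose kernel is the sum over input channels."""
    gray = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        # [out, in, kH, kW] -> [out, 1, kH, kW]
        gray.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            gray.bias.copy_(conv.bias)
    return gray


# Stand-in for the first layers of a 3-channel feature extractor (assumption).
features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU())
rgb_conv = features[0]
features[0] = to_grayscale_conv(rgb_conv)  # swap the layer itself, not only its weights

# Summed weights behave like the original layer fed a gray image copied to 3 channels.
x = torch.randn(2, 1, 16, 16)
with torch.no_grad():
    same = torch.allclose(features[0](x), rgb_conv(x.repeat(1, 3, 1, 1)), atol=1e-4)
print(tuple(features[0].weight.shape), same)
```

Because the layer object is replaced, its declared shape and its weight agree, so no `load_state_dict` call with a mismatched tensor is needed.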

Upvotes: 1

Florian Blume

Reputation: 3345

You haven't specified where your VGG class comes from but I assume it's from torchvision.models.

The VGG model is built for images with 3 channels. You can see this in the make_layers function in the torchvision source on GitHub.

It's probably not a good idea to modify the code inside the torchvision package, but you could create a copy within your project and make in_channels settable.
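
As a rough illustration of such a copy, here is a simplified sketch modeled on the idea of `make_layers`, not the exact torchvision code. The `in_channels` parameter and the `VGG16_CFG` name are assumptions made for this example, and batch-norm support is left out:

```python
import torch
import torch.nn as nn

# VGG16 layout: numbers are conv output channels, "M" is a 2x2 max-pool.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def make_layers(cfg, in_channels: int = 3) -> nn.Sequential:
    """Build the convolutional part of a VGG-style network with a settable input depth."""
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v  # the next conv consumes this layer's output
    return nn.Sequential(*layers)


gray_features = make_layers(VGG16_CFG, in_channels=1)
out = gray_features(torch.randn(1, 1, 32, 32))
print(gray_features[0].in_channels, tuple(out.shape))
```

Note that a network built this way is randomly initialized. Pretrained weights would still have to be copied in separately, and the first layer's weight would need to be reduced to one channel before it fits.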

Upvotes: 1
