I have been following the ants and bees transfer learning tutorial from the official PyTorch docs. I am trying to fine-tune a VGG19 model by changing the final layer to predict one of two classes. I am able to modify the last fc layer using the code below.
However, I get an error when executing the train_model function: “size mismatch at /opt/conda/conda-bld/pytorch_1513368888240/work/torch/lib/THC/generic/”. Any idea what the issue is?
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
# use_gpu and train_model are defined in the transfer learning tutorial
model_conv = torchvision.models.vgg19(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False
model_conv = nn.Sequential(*list(model_conv.classifier.children())[:-1] +
                           [nn.Linear(in_features=4096, out_features=2)])
if use_gpu:
    model_conv = model_conv.cuda()
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_conv._modules['6'].parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)
When you define your model, you keep only the classifier, which consists of the fully connected part of the network. When you then feed a 224×224×3 image to the model, it tries to pass it straight through the first linear layer, which expects a flattened input of 25,088 features (512 × 7 × 7, the output of the convolutional part). To solve it, add the convolutional part back in front of the classifier by redefining the model like this:
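import torch.nn as nn
class newModel(nn.Module):
    def __init__(self, old_model):
        super(newModel, self).__init__()
        # Reuse the pretrained (frozen) convolutional feature extractor
        self.features = old_model.features
        # Keep every classifier layer except the last, then add a 2-class output layer
        self.classifier = nn.Sequential(*list(old_model.classifier.children())[:-1] +
                                        [nn.Linear(in_features=4096, out_features=2)])
    def forward(self, x):
        x = self.features(x)       # (N, 512, 7, 7) for 224x224 inputs
        # Newer torchvision VGGs apply avgpool here; for 224x224 inputs it is a no-op
        x = x.view(x.size(0), -1)  # flatten to (N, 25088)
        x = self.classifier(x)
        return x
model_conv = newModel(model_conv)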
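The roughly 25K input features expected by the classifier come from the convolutional part. As a pure-Python sketch of the shape arithmetic, assuming torchvision's VGG19 layer configuration (sixteen 3×3 convolutions with padding 1, which keep height and width, and five 2×2 max-pools with stride 2, which halve them), the following shows where the number comes from. `vgg_feature_shape` and `VGG19_CFG` are hypothetical names used only for illustration.

```python
# Assumed VGG19 feature configuration (torchvision's "E" layout):
# integers are 3x3 conv output channels, "M" is a 2x2 max-pool with stride 2.
VGG19_CFG = (64, 64, "M", 128, 128, "M",
             256, 256, 256, 256, "M",
             512, 512, 512, 512, "M",
             512, 512, 512, 512, "M")


def vgg_feature_shape(height, width, channels=3, cfg=VGG19_CFG):
    """Return (C, H, W) of the tensor produced by the conv/pool stack."""
    for layer in cfg:
        if layer == "M":
            # 2x2 max-pool, stride 2, no padding: floor-halve H and W
            height, width = height // 2, width // 2
        else:
            # 3x3 conv, stride 1, padding 1: H and W unchanged, channels replaced
            channels = layer
    return channels, height, width


c, h, w = vgg_feature_shape(224, 224)
print("features output:", (c, h, w))
print("flattened size:", c * h * w)
print("raw image size:", 3 * 224 * 224)
```

For a 224×224 input this gives (512, 7, 7), i.e. 25,088 features, whereas the raw image holds 150,528 values laid out as a 4D tensor, which is why the classifier alone cannot consume it.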
You also need to tell the optimizer which parameters to update. If you only want to train the last layer (the newly added one), do:
optimizer_conv = optim.SGD(model_conv.classifier[6].parameters(), lr=0.001, momentum=0.9)
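Hard-coding the layer index works, but it becomes brittle if you later unfreeze more layers. A minimal sketch of an alternative, assuming you want to optimize exactly the parameters whose `requires_grad` is still True: because every pretrained weight was frozen before the new `nn.Linear` was created, this selects the same parameters as the line above. `make_sgd_for_trainable` is a hypothetical helper, not a PyTorch API.

```python
import torch.nn as nn
import torch.optim as optim


def make_sgd_for_trainable(model: nn.Module, lr: float = 0.001,
                           momentum: float = 0.9) -> optim.SGD:
    """Build SGD over only the parameters that still require gradients.

    Hypothetical helper: frozen parameters (requires_grad=False) are skipped,
    so with the frozen VGG19 above only the new output layer is optimized.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    return optim.SGD(trainable, lr=lr, momentum=momentum)
```

Usage would be `optimizer_conv = make_sgd_for_trainable(model_conv)`; if the list ends up empty, `optim.SGD` itself raises a `ValueError`, which usually means the new layer was created before the freezing loop.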
The rest of the code remains the same.
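Another common way to reach the same result, sketched here as an option rather than what this answer does, is to keep the torchvision model intact and swap only its last classifier layer in place. That keeps the model's own `forward`, including the `avgpool` layer that newer torchvision VGGs apply before the classifier. `replace_last_layer` is a hypothetical helper, and it assumes `model.classifier` is an `nn.Sequential` ending in an `nn.Linear`, as in torchvision's VGG models.

```python
import torch.nn as nn


def replace_last_layer(model: nn.Module, num_classes: int) -> nn.Module:
    """Freeze every parameter of a VGG-style model, then swap its last layer.

    Hypothetical helper: assumes model.classifier is an nn.Sequential whose
    final entry is an nn.Linear (true for torchvision's VGG models).
    """
    for param in model.parameters():
        param.requires_grad = False  # freeze all pretrained weights
    old_last = model.classifier[-1]
    # Created after freezing, so the new layer's weight and bias are trainable.
    model.classifier[-1] = nn.Linear(old_last.in_features, num_classes)
    return model
```

For example, `model_conv = replace_last_layer(torchvision.models.vgg19(pretrained=True), 2)` leaves `classifier[6]` as the only trainable layer, so the optimizer line above still applies. On torchvision 0.13 and later, `pretrained=True` is deprecated in favor of the `weights` argument.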
Hope it helps!