dqmis

Reputation: 489

Finetune PyTorch model after training fc layers

I am trying to do transfer learning using PyTorch. I want to train the fc layer first and then fine-tune the whole network. Unfortunately, after training the fc layer and then passing my network to fine-tuning, I am losing the accuracy that was acquired in the first training. Is this expected behaviour or am I doing something wrong here?

Here is the code:

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

model = torchvision.models.resnet50(pretrained=True)

# Freeze the pretrained backbone so only the new head is trained first
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head with a new 4-class fc layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 4)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Stage 1: train only the fc layer (everything else is frozen)
model = trainer.fit_model(dataloader, model, criterion, optimizer, num_epochs=10)
# fit_model is the basic PyTorch training function from
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor
# The only difference is that the scheduler is an optional parameter.

# Unfreeze all layers before fine-tuning the whole network
for param in model.parameters():
    param.requires_grad = True

torch.cuda.empty_cache()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Stage 2: fine-tune the whole model
model_ft = trainer.fit_model(
    dataloader, model, criterion, optimizer, scheduler=exp_lr_scheduler, num_epochs=10
)

Am I missing something here or should I just train the model once?

Upvotes: 2

Views: 491

Answers (1)

Statistic Dean

Reputation: 5270

This is something that can happen when performing transfer learning; it is called catastrophic forgetting. Basically, you update the pretrained weights too much and 'forget' what was previously learned. This happens notably when your learning rate is too high. I would suggest first trying a lower learning rate, or using differential learning rates (a different learning rate for the head of the network than for the pretrained part, so that the fc layer can have a higher learning rate than the rest of the network).
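A minimal sketch of the differential learning rate idea, assuming the same resnet50-based model and Adam optimizer as in the question (the split on the "fc." name prefix and the specific rate values are my own illustration, not from the original post):

# Collect every parameter except those of the new fc head
backbone_params = [
    p for name, p in model.named_parameters() if not name.startswith("fc.")
]

# One optimizer with two parameter groups at different learning rates
optimizer = optim.Adam([
    {"params": backbone_params, "lr": 1e-5},        # pretrained layers: gentle updates
    {"params": model.fc.parameters(), "lr": 1e-3},  # new fc head: larger updates
])

Passing this optimizer to the second fit_model call lets the head keep learning quickly while the pretrained weights move slowly, which reduces the risk of catastrophic forgetting.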

Upvotes: 3
