Tom Hale

PyTorch transfer learning with pre-trained ImageNet model

I want to create an image classifier using transfer learning on a model already trained on ImageNet.

How do I replace the final layer of a torchvision.models ImageNet classifier with my own custom classifier?

Answers (1)

Tom Hale

Get a pre-trained ImageNet model (resnet152 is the most accurate of the ResNet variants in torchvision):

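from torchvision import models
# https://pytorch.org/docs/stable/torchvision/models.html
# torchvision >= 0.13 takes a weights enum; older versions use pretrained=True
model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)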

Print its structure so we can compare it with the modified network later:

print(model)

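Remove the last top-level module (for ResNet, the single fully connected layer fc) from the model. _modules is a private attribute, but it stores the child modules in registration order, so popitem() removes and returns the last one: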

classifier_name, old_classifier = model._modules.popitem()

Freeze the parameters of the feature detector part of the model so that they are not adjusted by back-propagation. Do this before attaching the new classifier, so that only the pretrained layers are frozen:

for param in model.parameters():
    param.requires_grad = False

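Create a new classifier. Its input size must match the in_features of the layer that was removed; hidden_layer_size and output_layer_size (the number of classes) are yours to choose: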

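from collections import OrderedDict
from torch import nn

hidden_layer_size = 512  # example value; tune for your task
output_layer_size = 10   # example value: the number of classes in your dataset

classifier_input_size = old_classifier.in_features

classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(classifier_input_size, hidden_layer_size)),
    ('activation', nn.SELU()),
    ('dropout', nn.Dropout(p=0.5)),
    ('fc2', nn.Linear(hidden_layer_size, output_layer_size)),
    # Outputs log-probabilities, so train with nn.NLLLoss
    ('output', nn.LogSoftmax(dim=1)),
]))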
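Because the classifier ends in nn.LogSoftmax, its output is log-probabilities, and the matching loss is nn.NLLLoss rather than nn.CrossEntropyLoss (which applies log-softmax itself). A small check of that equivalence on random tensors; the shapes and targets are arbitrary illustration values:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw scores, e.g. the output of fc2
targets = torch.tensor([1, 0, 9, 3])   # class indices for a batch of 4

# A LogSoftmax head is trained with NLLLoss ...
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
# ... which gives the same loss as CrossEntropyLoss applied to raw logits
ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(nll, ce))  # True
```

Using CrossEntropyLoss on top of the LogSoftmax layer would apply log-softmax twice, which still trains but distorts the loss.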

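The new classifier must be registered under the same name as the module that was removed (fc for ResNet), because the model's forward() method calls the final layer by that attribute name. Add the new classifier to the end of the feature detector: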

model.add_module(classifier_name, classifier)

Finally, print out the structure of the new network:

print(model)
