Ahasanul Haque

Reputation: 11134

How to strip a pretrained network and add some layers to it using pytorch lightning?

I am trying to use transfer learning for an image segmentation task. My plan is to use the first few layers of a pretrained model (VGG16, for example) as an encoder and then add my own decoder.

So, I can load the model and see the structure by printing it:

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
print(model)

I get output like this:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  .....
  .....
  .....

I can also access specific layers, with model.layer3 for example. Now I am struggling with two things:

  1. How do I cut the model and keep every module from the beginning through the end of a given layer (model.layer3, for example)?
  2. How do I freeze only this stripped part and keep the newly added modules available for training?

Upvotes: 0

Views: 4455

Answers (3)

drgxfs

Reputation: 1127

For 1): Initialize the ResNet in your LightningModule and slice it up to the part that you need. Then add your own head after that, and define forward to call the parts in the order that you need. See this example, based on the transfer learning docs:

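import torch.nn as nn
import torchvision.models as models
from pytorch_lightning import LightningModule

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet and drop its final fully connected layer
        backbone_tmp = models.resnet50(pretrained=True)
        num_filters = backbone_tmp.fc.in_features
        layers = list(backbone_tmp.children())[:-1]
        self.backbone = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        # backbone output is (N, num_filters, 1, 1); flatten it for the linear head
        features = self.backbone(x).flatten(1)
        return self.classifier(features)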

For 2): Pass a BackboneFinetuning callback to your trainer. This requires that your LightningModule has a self.backbone attribute containing the modules that you want to be frozen, as in the snippet above. You can also subclass the BaseFinetuning callback if you need different freeze-unfreeze behavior.

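from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BackboneFinetuning

# after unfreezing, multiply the backbone learning rate by 1.5 each epoch
multiplicative = lambda epoch: 1.5
# keep self.backbone frozen until epoch 200
backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=200, lambda_func=multiplicative)
trainer = Trainer(callbacks=[backbone_finetuning])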

Upvotes: 3

Ivan

Reputation: 40618

The following is true for any child module of model, but I will answer your question with model.layer3 here:

  1. model.layer3 gives you the nn.Module associated with layer n°3 of your model. You can call it directly, as you would with model itself:

    >>> z = model.layer3(torch.rand(16, 128, 10, 10))
    >>> z.shape
    torch.Size([16, 256, 5, 5])
    
  2. To freeze the model:

    • you could put the layer in eval mode, which disables dropout and makes BN layers use the running statistics learned during training instead of the current batch statistics. This is done with model.layer3.eval()

    • you must disable training on that layer by toggling the requires_grad flag: model.layer3.requires_grad_(False). This affects all child parameters.
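
As a rough illustration of both points, here is a minimal sketch in plain PyTorch. The backbone and head are small hypothetical stand-ins for the sliced encoder and the new layers, not the actual ResNet:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the frozen encoder and the newly added head.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
head = nn.Conv2d(8, 2, kernel_size=1)

backbone.requires_grad_(False)  # no gradients for any backbone parameter
backbone.eval()                 # BN uses its stored running statistics

# Give the optimizer only the parameters that should still be trained.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

# Snapshots to check what one training step changes.
conv_weight_before = backbone[0].weight.detach().clone()
bn_mean_before = backbone[1].running_mean.clone()
head_bias_before = head.bias.detach().clone()

x = torch.rand(4, 3, 16, 16)
loss = head(backbone(x)).mean()
loss.backward()
optimizer.step()

print(torch.equal(backbone[0].weight, conv_weight_before))  # backbone weights untouched
print(torch.equal(backbone[1].running_mean, bn_mean_before))  # BN statistics untouched
print(torch.equal(head.bias, head_bias_before))  # head was updated
```

Note that eval mode is not sticky: a later call to .train() on the parent module switches the backbone back to training mode, so you may need to call backbone.eval() again (for example at the start of each training epoch) if you want the BN statistics to stay fixed.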

Upvotes: 1
