I am trying to use transfer learning for an image segmentation task, and my plan is to use the first few layers of a pretrained model (VGG16, for example) as an encoder and then add my own decoder.
So, I can load the model and see the structure by printing it:
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
print(model)
I get output like this:
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
.....
.....
.....
I can also access specific layers with model.layer3, for example. Now, I am struggling with certain things:

1. How can I strip the pretrained model and use only the first few layers as my encoder?
2. How can I freeze only this stripped part, and keep the newly added modules available for training?
For 1): Initialize the ResNet in your LightningModule and slice it up to the part that you need. Then add your own head after that, and define forward in the order that you need. See this example, based on the transfer learning docs:
import torch.nn as nn
import torchvision.models as models
from pytorch_lightning import LightningModule

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()
        # init a pretrained resnet
        backbone_tmp = models.resnet50(pretrained=True)
        num_filters = backbone_tmp.fc.in_features
        # keep everything except the final fully connected layer
        layers = list(backbone_tmp.children())[:-1]
        self.backbone = nn.Sequential(*layers)
        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        # run the pretrained backbone first, then the newly added head
        features = self.backbone(x).flatten(1)
        return self.classifier(features)
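Since your goal is a segmentation encoder rather than a classifier, you would typically slice fewer children and attach a decoder instead of a Linear head. Here is a minimal sketch under that assumption; the class name and the decoder architecture are hypothetical, not from the docs:

import torch.nn as nn
import torchvision.models as models
from pytorch_lightning import LightningModule

class SegmentationTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()
        backbone_tmp = models.resnet18(pretrained=True)
        # children() order for ResNet: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc
        # keep everything up to and including layer3 as the encoder
        self.backbone = nn.Sequential(*list(backbone_tmp.children())[:7])
        # hypothetical decoder: replace with your own upsampling head
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False),
        )

    def forward(self, x):
        features = self.backbone(x)   # (N, 256, H/16, W/16) for resnet18 up to layer3
        return self.decoder(features)

The same slicing idea works for VGG16, whose convolutional part is exposed as models.vgg16(pretrained=True).features.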
For 2): Pass a BackboneFinetuning callback to your trainer. This requires that your LightningModule has a self.backbone attribute containing the modules that you want frozen, as shown in the snippet above. You can also use the BaseFinetuning callback if you need different freeze-unfreeze behavior.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BackboneFinetuning

# keep the backbone frozen until epoch 200, then scale its learning rate by 1.5x per epoch
multiplicative = lambda epoch: 1.5
backbone_finetuning = BackboneFinetuning(200, multiplicative)
trainer = Trainer(callbacks=[backbone_finetuning])
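If the fixed unfreeze-at-an-epoch behavior is not what you want, a rough sketch of a custom BaseFinetuning callback could look like the following. The milestone and learning rate are made up, and the exact finetune_function signature varies slightly across pytorch_lightning versions:

from pytorch_lightning.callbacks import BaseFinetuning

class FreezeBackboneUntilMilestone(BaseFinetuning):
    def __init__(self, milestone=10):
        super().__init__()
        self.milestone = milestone  # hypothetical epoch at which to unfreeze

    def freeze_before_training(self, pl_module):
        # freeze the encoder before training starts
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # once the milestone is reached, unfreeze the encoder and give it its own param group
        if current_epoch == self.milestone:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                lr=1e-4,  # hypothetical lower learning rate for the pretrained part
            )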
The following is true for any child module of model, but I will answer your question with model.layer3 here: model.layer3 will give you the nn.Module associated with layer 3 of your model. You can call it directly, as you would call model itself:
>>> z = model.layer3(torch.rand(16, 128, 10, 10))
>>> z.shape
torch.Size([16, 256, 5, 5])
To freeze the model:

- You can put the layer in eval mode, which disables dropout and makes BatchNorm layers use the statistics learned during training. This is done with model.layer3.eval().
- You must disable training on that layer by toggling the requires_grad flag: model.layer3.requires_grad_(False). This will affect all child parameters.
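A quick sketch putting both steps together and verifying that only the stripped part is frozen (layer3 is just the example layer from above):

# put the pretrained part in eval mode and stop its gradients
model.layer3.eval()
model.layer3.requires_grad_(False)

# sanity check: no parameter of layer3 requires gradients anymore
assert not any(p.requires_grad for p in model.layer3.parameters())

# note: calling model.train() later puts layer3 back in train mode,
# so re-apply .eval() to it after switching the full model to training mode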