Reputation: 485
I am migrating from Keras/TF and I am having a little trouble understanding the transfer learning process in PyTorch.
I want to use the pytorch-lightning framework and be able to switch between different neural networks in one script.
Following this example, we can switch between different neural networks inside the LightningModule's implementation:
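import pytorch_lightning as pl

class BERT(pl.LightningModule):
    def __init__(self, model_name, task):
        super().__init__()  # required before assigning submodules
        self.task = task
        # Transformer and MyCoolVersion are nn.Module subclasses defined elsewhere
        if model_name == 'transformer':
            self.net = Transformer()
        elif model_name == 'my_cool_version':
            self.net = MyCoolVersion()
        else:
            raise ValueError(f"unknown model_name: {model_name!r}")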
The question is: how do I create a new neural network that extends nn.Module and uses transfer learning?
My own implementation looks like this: I am using the VGG16 network and have replaced its classifier with a single fully connected layer with two output neurons.
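import torch.nn as nn
from torchvision.models import vgg16

class VGGNetwork(nn.Module):
    def __init__(self):
        super(VGGNetwork, self).__init__()
        # vgg16 is the default model here; other variants (e.g. vgg16_bn) work the same way
        # (newer torchvision versions use weights=VGG16_Weights.DEFAULT instead of pretrained=True)
        self.model = vgg16(pretrained=True)
        # replace the three-layer classifier with a single linear layer with 2 outputs;
        # 512 * 7 * 7 is the flattened size of the features after VGG's 7x7 adaptive pooling
        self.model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 2))

    def forward(self, x):
        # call the module rather than .forward() so that registered hooks run
        return self.model(x)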
Is this the correct way to do it?
Upvotes: 2
Views: 813
Reputation: 173
The pytorch-lightning documentation has an example of this: https://pytorch-lightning.readthedocs.io/en/0.7.1/transfer_learning.html
...
import torch.nn as nn
import pytorch_lightning as pl

class AutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # assumed: return the encoder's representation so the module works as a feature extractor
        return self.encoder(x)

class CIFAR10Classifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()
        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        return x
...
Upvotes: 0
Reputation: 602
You can freeze the weights and biases of all layers except the last one
by setting requires_grad = False on their parameters:
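# freeze every parameter of the pretrained model; a layer created after this loop
# (e.g. a new final layer) is trainable by default
for param in model_conv.parameters():
    param.requires_grad = False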
You can find more about this in the PyTorch transfer learning tutorial: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
Upvotes: 2