This is what I did to get the name of the last layer:
list(tmp.state_dict().keys())[-1].split('.')[0]
What is the proper way? My goal is to replace the last layer for the purpose of transfer learning.
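One alternative to parsing the last `state_dict()` key is a minimal sketch built on `nn.Module.named_children()`, which yields the direct submodules in registration order. Unlike `state_dict()`, it also lists modules that hold no parameters (e.g. a trailing `nn.ReLU`). `TinyNet`, `last_child_name` and `replace_last_layer` are hypothetical names used only for illustration. The sketch assumes the model registers its layers in the same order `forward()` uses them, and that the last one is an `nn.Linear`.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for a pretrained network.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 10)  # the "last layer" we want to swap

    def forward(self, x):
        return self.head(self.features(x))

def last_child_name(model: nn.Module) -> str:
    # named_children() yields (name, module) pairs for the direct submodules,
    # in registration order, including modules without parameters.
    name, _ = list(model.named_children())[-1]
    return name

def replace_last_layer(model: nn.Module, num_classes: int) -> nn.Module:
    name = last_child_name(model)
    old = getattr(model, name)
    # Assumption: the last child is an nn.Linear; adapt for other layer types.
    setattr(model, name, nn.Linear(old.in_features, num_classes))
    return model

model = TinyNet()
print(last_child_name(model))  # head
model = replace_last_layer(model, num_classes=3)
print(model(torch.randn(4, 8)).shape)  # torch.Size([4, 3])
```

Because the new layer is assigned under the same attribute name, the model's own `forward()` keeps working unchanged.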
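You can simply follow these steps to remove the last layer from a pretrained PyTorch model:
1. Get the model's top-level layers as a list with list(model.children()).
2. Drop the last layer by slicing that list with [:-1].
3. Finally, use the PyTorch class nn.Sequential() to stack this modified list together into a new model: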
nn.Sequential(*list(model.children())[:-1])
You can read more about this here.
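Here is a minimal end-to-end sketch of the steps above for transfer learning. It strips the old classifier, optionally freezes the remaining layers, and attaches a new head. The toy `pretrained` model and the 3-class target are assumptions for illustration. This trick is only safe when the model's `forward()` just applies its children in order. Operations performed directly inside `forward()` are not children and are silently lost. One example is the `torch.flatten` call in torchvision's ResNet.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained model whose forward() simply
# applies its children in order (the case where this trick is safe).
pretrained = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 10),  # original classifier to be dropped
)

# 1. list the top-level layers, 2. drop the last one, 3. re-stack them.
backbone = nn.Sequential(*list(pretrained.children())[:-1])

# Optionally freeze the pretrained weights so only the new head trains.
for p in backbone.parameters():
    p.requires_grad = False

# Attach a new, trainable head for the target task (3 classes, assumed).
model = nn.Sequential(backbone, nn.Linear(16, 3))

x = torch.randn(4, 8)
print(backbone(x).shape)  # torch.Size([4, 16])
print(model(x).shape)     # torch.Size([4, 3])
```

Note that `backbone` shares its layer objects with `pretrained`, so freezing one also freezes the corresponding layers of the other.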