AturSams

Reputation: 7952

Getting the last layer from a pretrained PyTorch model for transfer learning?

This is what I did:

    list(tmp.state_dict().keys())[-1].split('.')[0]

What is the proper way to do this? My goal is to replace the last layer for transfer learning.
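A minimal sketch of a more direct alternative, assuming PyTorch is installed. `TinyNet` is a hypothetical toy model standing in for a pretrained network, and the 10-class head is an arbitrary choice. `named_children()` yields the direct submodules as `(name, module)` pairs in registration order, so the last pair gives both the attribute name and the layer itself. Unlike `state_dict()` keys, it also sees layers that have no parameters. Registration order usually, but not always, matches the order used in `forward()`, so check the model definition. Many torchvision models expose their head as a named attribute anyway (ResNet uses `fc`).

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained network (not a real torchvision model).
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))

model = TinyNet()

# Last registered top-level submodule: its attribute name and the module itself.
last_name, last_layer = list(model.named_children())[-1]
print(last_name, last_layer)

# Swap in a new head for the target task (10 classes is an assumption).
setattr(model, last_name, nn.Linear(last_layer.in_features, 10))

out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```

Because the replacement is done with `setattr` on the original module, the model's own `forward()` (including its `torch.flatten` call) keeps working unchanged.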

Upvotes: 0

Views: 1450

Answers (1)

Anubhav Singh

Reputation: 8699

You can simply follow these steps to drop the last layer of a pretrained PyTorch model and keep everything before it:

  • Get the model's top-level layers with model.children().
  • Convert the returned iterator into a list with list().
  • Remove the last layer by slicing the list with [:-1].
  • Finally, pass the remaining layers to the PyTorch module nn.Sequential to stack them together into a new model.

    import torch.nn as nn
    nn.Sequential(*list(model.children())[:-1])
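A minimal sketch of the full transfer-learning step built on the line above, assuming PyTorch is installed. `TinyNet` is a hypothetical toy model standing in for a pretrained network, and the frozen backbone and 10-class head are illustrative choices. One caveat: `nn.Sequential` only chains the child modules, so operations written directly in the original `forward()` (here `torch.flatten`) are lost and must be re-added, e.g. with `nn.Flatten()`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained model; its forward() flattens
# between the pooling layer and the final Linear layer.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))

model = TinyNet()

# Drop the last child (fc) and keep everything before it.
backbone = nn.Sequential(*list(model.children())[:-1])

# Freeze the pretrained part so only the new head is trained.
for p in backbone.parameters():
    p.requires_grad = False

# The flatten from TinyNet.forward() is not a child module, so add it back
# explicitly before the new head (10 classes assumed).
new_model = nn.Sequential(backbone, nn.Flatten(), nn.Linear(8, 10))

x = torch.randn(2, 3, 32, 32)
print(backbone(x).shape)   # torch.Size([2, 8, 1, 1])
print(new_model(x).shape)  # torch.Size([2, 10])
```

Only the new `nn.Linear(8, 10)` head has trainable parameters (8 × 10 weights + 10 biases = 90), so an optimizer built from the parameters with `requires_grad=True` would update just that layer.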

You can read more about this here.

Upvotes: 3
