Reputation: 5550
In PyTorch, we load a pretrained model as follows:
net.load_state_dict(torch.load(path)['model_state_dict'])
This requires the network structure and the saved weights to match exactly. However, is it possible to load the weights and then modify the network, e.g. add an extra parameter?
Note: if we add an extra parameter to the model before loading the weights, e.g.
self.parameter = Parameter(torch.ones(5), requires_grad=True)
then loading the weights fails with a
Missing key(s) in state_dict:
error.
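For concreteness, here is a minimal sketch that reproduces the error (the class name Net, the layer sizes, and the checkpoint path are illustrative, not taken from the original code):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, extra=False):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 10)
        if extra:
            # parameter that did not exist when the checkpoint was saved
            self.parameter = nn.Parameter(torch.ones(5), requires_grad=True)

torch.save({'model_state_dict': Net().state_dict()}, 'ckpt.pt')

net = Net(extra=True)
# raises RuntimeError: Missing key(s) in state_dict: "parameter".
net.load_state_dict(torch.load('ckpt.pt')['model_state_dict'])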
Upvotes: 2
Views: 1796
Reputation: 37771
Let's create a model and save its state.
import torch
import torch.nn as nn

class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
        self.encoder = nn.LSTM(100, 50)

    def forward(self):
        pass

model1 = Model1()
torch.save(model1.state_dict(), 'filename.pt')  # saving the model's weights
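As a quick sanity check, you can inspect which keys were saved; the names are derived from the module's attribute names:

# prints the LSTM parameters under the 'encoder.' prefix,
# e.g. 'encoder.weight_ih_l0', 'encoder.weight_hh_l0', ...
print(list(model1.state_dict().keys()))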
Then create a second model that shares some layers with the first one. Load the state of the first model and copy it into the matching layers of the second.
class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()
        self.encoder = nn.LSTM(100, 50)
        self.linear = nn.Linear(50, 200)

    def forward(self):
        pass

model1_dict = torch.load('filename.pt')
model2 = Model2()
model2_dict = model2.state_dict()

# 1. filter out keys that do not exist in model2
filtered_dict = {k: v for k, v in model1_dict.items() if k in model2_dict}
# 2. overwrite the matching entries in model2's state dict
model2_dict.update(filtered_dict)
# 3. load the merged state dict
model2.load_state_dict(model2_dict)
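As a shorter alternative, load_state_dict also accepts strict=False, which skips keys that do not line up instead of raising an error, and returns a NamedTuple telling you what was skipped:

# loads the 'encoder.*' weights and ignores the rest
result = model2.load_state_dict(model1_dict, strict=False)
print(result.missing_keys)     # keys model2 expects but the checkpoint lacks (here: the linear layer)
print(result.unexpected_keys)  # keys in the checkpoint that model2 does not use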
Upvotes: 4