Reputation: 3854
I have my own network in PyTorch. I first trained it as a binary classifier (2 classes). After 10k epochs, I saved the trained weights as 10000_model.pth. Now I want to use the same network for a 4-class classification problem. So I want to transfer all the trained weights from the binary classifier to the 4-class model, except the last layer, which should be randomly initialized. How can I do that? This is my model:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify = nn.Conv2d(50, 2, 1, 1, bias=True)  # number of classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_classify(x))
        return x
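Since the 2-class and 4-class models differ only in the number of output channels of conv_classify, one way to build both from a single class is a constructor argument. A minimal sketch, assuming a hypothetical num_classes argument (not in the original code) and 28x28 single-channel inputs, which the question does not specify:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same layers as above; only the classifier's output channels change.
    # `num_classes` is a hypothetical argument added for illustration.
    def __init__(self, num_classes=2):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify = nn.Conv2d(50, num_classes, 1, 1, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_classify(x))
        return x


binary_model = Net(num_classes=2)
four_class_model = Net(num_classes=4)

# Assumed input: a batch of 8 single-channel 28x28 images.
x = torch.randn(8, 1, 28, 28)
out2 = binary_model(x)
out4 = four_class_model(x)
print(out2.shape)  # torch.Size([8, 2, 4, 4])
print(out4.shape)  # torch.Size([8, 4, 4, 4])
```

With 28x28 inputs, the two 5x5 convolutions and two 2x2 poolings shrink the spatial size to 4x4, so only the channel dimension differs between the two models.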
This is what I did:
import torch

model = Net()
checkpoint_dict = torch.load('10000_model.pth')
pretrained_dict = checkpoint_dict['state_dict']
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
For now, I have to manually delete the classifier entries from pretrained_dict by name:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
pretrained_dict.pop('conv_classify.weight', None)
pretrained_dict.pop('conv_classify.bias', None)
In other words, the filter
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
does not remove anything.
What is wrong? I am using PyTorch 1.0. Thanks.
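A minimal, self-contained reproduction of the problem, assuming the checkpoint's 'state_dict' is an ordinary state dict of the 2-class model (simulated here in memory instead of loading 10000_model.pth) and that the 4-class model differs only in conv_classify's output channels (the num_classes argument is hypothetical). It shows that the name-based filter keeps every entry, that loading the merged dict raises a size-mismatch RuntimeError, and that popping the two classifier entries lets the load succeed:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Same layers as the question's Net; forward() is omitted because only
    # the state dicts are inspected. `num_classes` is a hypothetical argument.
    def __init__(self, num_classes):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify = nn.Conv2d(50, num_classes, 1, 1, bias=True)


# Stand-in for checkpoint_dict['state_dict'] from the 2-class training run.
pretrained_dict = Net(num_classes=2).state_dict()

model = Net(num_classes=4)
model_dict = model.state_dict()

# The name-based filter keeps everything, because the names are identical.
filtered = {k: v for k, v in pretrained_dict.items() if k in model_dict}
print(len(filtered), len(pretrained_dict))  # 6 6

# Loading the merged dict fails: the classifier tensors have the wrong shape.
try:
    model.load_state_dict({**model_dict, **filtered})
    error_message = ""
except RuntimeError as err:
    error_message = str(err)
print("size mismatch" in error_message)  # True

# The manual workaround: drop the classifier entries before merging.
filtered.pop('conv_classify.weight', None)
filtered.pop('conv_classify.bias', None)
model_dict.update(filtered)
model.load_state_dict(model_dict)  # now succeeds
```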
Upvotes: 0
Views: 76
Reputation: 13113
Both networks have the same layers and therefore the same keys in state_dict, so indeed
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
does nothing. The difference between the two lies in the shapes of the weight tensors, not in their names. In other words, you can distinguish the two by [v.shape for v in model.state_dict().values()] but not by model.state_dict().keys(). Your "workaround" approach is correct. If you want to make this a bit less manual, I would use
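merged_dict = {}
for key in model_dict.keys():
    if 'conv_classify' in key:  # or perhaps a more complex criterion
        # keep the new model's randomly initialized classifier
        merged_dict[key] = model_dict[key]
    else:
        # reuse the pretrained feature-extractor weights
        merged_dict[key] = pretrained_dict[key]
model.load_state_dict(merged_dict)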
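Building on the point that the two models differ in tensor shapes rather than key names, here is a sketch of a filter that compares both, so no layer names need to be hard-coded. It uses a hypothetical helper merge_compatible, the same hypothetical num_classes argument, and an in-memory stand-in for the checkpoint; every skipped entry keeps the new model's random initialization:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Same layers as the question's Net (forward() omitted for brevity);
    # `num_classes` is a hypothetical constructor argument.
    def __init__(self, num_classes):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify = nn.Conv2d(50, num_classes, 1, 1, bias=True)


def merge_compatible(model, pretrained_dict):
    """Hypothetical helper: copy every pretrained tensor whose name AND
    shape match the model, and return the sorted list of skipped keys."""
    model_dict = model.state_dict()
    compatible = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return sorted(set(pretrained_dict) - set(compatible))


pretrained_dict = Net(num_classes=2).state_dict()  # stand-in for the checkpoint
model = Net(num_classes=4)
skipped = merge_compatible(model, pretrained_dict)
print(skipped)  # ['conv_classify.bias', 'conv_classify.weight']
```

Passing strict=False to load_state_dict would not replace this check on its own: strict only governs missing and unexpected keys, and size mismatches are still raised as errors.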
Upvotes: 2