Jame

Reputation: 3854

How to transfer the weights of my own model to the same network but with a different number of classes in the last layer?

I have my own network in PyTorch. It was first trained as a binary classifier (2 classes). After 10k epochs, I saved the trained weights as 10000_model.pth. Now I want to use the model for a 4-class classification problem with the same network. Thus, I want to transfer all the trained weights from the binary classifier to the 4-class problem, except for the last layer, which will be randomly initialized. How can I do it? This is my model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.conv_classify = nn.Conv2d(50, 2, 1, 1, bias=True)  # number of classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_classify(x))
        return x

This is what I did:

model = Net()
checkpoint_dict = torch.load('10000_model.pth')
pretrained_dict = checkpoint_dict['state_dict']
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

For now, I have to manually delete the entries from pretrained_dict by name.

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
pretrained_dict.pop('conv_classify.weight', None)
pretrained_dict.pop('conv_classify.bias', None)

This means that pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} does not filter anything out.
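
Indeed, a quick check (a minimal sketch, assuming pretrained_dict and model_dict are built as above) confirms that both state dicts have exactly the same keys, so the filter is a no-op:

# both dicts come from the same architecture, so their keys are identical
print(set(pretrained_dict.keys()) == set(model_dict.keys()))  # prints True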

What is wrong? I am using PyTorch 1.0. Thanks.

Upvotes: 0

Views: 76

Answers (1)

Jatentaki

Reputation: 13113

Both networks have the same layers and therefore the same keys in state_dict, so indeed

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

does nothing. The difference between the two networks lies in the shapes of the weight tensors, not in their names. In other words, you can distinguish them by [v.shape for v in model.state_dict().values()] but not by model.state_dict().keys(). Your "workaround" approach is correct. If you want to make it a bit less manual, I would use

merged_dict = {}
for key in model_dict.keys():
    if 'conv_classify' in key:  # or perhaps a more complex criterion
        # keep the freshly initialized weights for the new last layer
        merged_dict[key] = model_dict[key]
    else:
        # reuse the pretrained weights for every other layer
        merged_dict[key] = pretrained_dict[key]
model.load_state_dict(merged_dict)
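
If you would rather not hard-code the layer name at all, here is a minimal sketch of the shape-based filtering described above (assuming the same pretrained_dict and model_dict as in the question); entries whose shapes do not match, i.e. the last layer, simply keep their random initialization:

# keep only pretrained tensors whose shape matches the new model
filtered_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# mismatched entries (conv_classify.*) retain their fresh initialization
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)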

Upvotes: 2
