Shubham

PyTorch dictionary keys not matching

I am trying to implement a convolutional LSTM I found online, but the state dictionary keys do not seem to match.

The pre-trained weights are stored in a pickled dictionary with the following keys:

pkl_load = torch.load(trained_model_dir)
print(pkl_load.keys())

odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....

However, the keys in the state_dict for the actual NN model are:

"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", ....

I get an error when I try to load the pre-trained weights into the model's state_dict, because the keys don't match. How can I work around this? (Sorry if this is easy; I am new to PyTorch.)
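To confirm that the extra prefix is the only difference, the two key lists can be compared as sets. This is a minimal sketch in plain Python that uses only the three keys shown above (the rest are elided in the question). The variable names are illustrative, and in practice the lists would come from pkl_load.keys() and model.state_dict().keys().

```python
# Keys copied from the question (only the first three; the rest are elided there)
checkpoint_keys = ["module.E.conv1.weight", "module.E.bn1.weight", "module.E.bn1.bias"]
model_keys = ["E.conv1.weight", "E.bn1.weight", "E.bn1.bias"]

# Keys the model expects but the checkpoint lacks, and vice versa
missing = sorted(set(model_keys) - set(checkpoint_keys))
unexpected = sorted(set(checkpoint_keys) - set(model_keys))
print("missing:", missing)
print("unexpected:", unexpected)

# Check whether removing a single leading "module." makes the key sets identical
prefix = "module."
stripped = [k[len(prefix):] for k in checkpoint_keys if k.startswith(prefix)]
only_prefix = len(stripped) == len(checkpoint_keys) and sorted(stripped) == sorted(model_keys)
print("differs only by prefix:", only_prefix)
```

If only_prefix is True, renaming the checkpoint keys is enough and no weights need to be reshaped.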

Answers (1)

shiv

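The checkpoint keys carry an extra "module." prefix, which typically appears when a model is saved while wrapped in torch.nn.DataParallel. You can strip that first dotted component from each key, for example: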

keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
# Drop the first dotted component ("module") from each key
res = ['.'.join(key.split('.')[1:]) for key in keys]
print(res)

Output:

['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
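The loop above drops the first dotted component of every key, so it would also mangle any key that has no "module." prefix, and it produces names only, not a dictionary that can be loaded. Below is a minimal sketch, assuming the checkpoint is an ordinary (Ordered)dict mapping names to tensors. It removes only a leading "module." and keeps each value, so the result can be passed to the model's load_state_dict. strip_prefix is a hypothetical helper, and the string values stand in for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that do not start with `prefix` are copied unchanged, and the
    original insertion order is preserved.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        renamed[new_key] = value
    return renamed


# Stand-in for the loaded checkpoint; the values would be tensors in practice
pkl_load = OrderedDict([
    ("module.E.conv1.weight", "w_conv1"),
    ("module.E.bn1.weight", "w_bn1"),
    ("module.E.bn1.bias", "b_bn1"),
])
new_state_dict = strip_prefix(pkl_load)
print(list(new_state_dict.keys()))

# With PyTorch and the question's model available, the renamed dict would then be loaded with:
# model.load_state_dict(new_state_dict)
```

Matching on the literal prefix rather than splitting on "." makes the helper safe to run on a checkpoint that was saved without DataParallel, since such keys pass through untouched.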
