I am trying to implement a convolutional LSTM I found online, and it seems that the dictionary keys are not matching:
The pre-trained weights are in a pickled dictionary with the following keys:
pkl_load = torch.load(trained_model_dir)
print(pkl_load.keys())
odict_keys(['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias', ....
However, the keys in the state_dict for the actual NN model are:
"E.conv1.weight", "E.bn1.weight", "E.bn1.bias", ....
I get an error when I try to load the pre-trained weights into the model's state_dict, because the keys don't match. What are some ways to work around this? (Sorry if this is a basic question; I am new to PyTorch.)
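A quick way to see exactly how the two key sets differ is to compare them as sets. The sketch below uses a hypothetical helper, `diff_keys`. It hard-codes the three keys shown above in place of `pkl_load.keys()` and `model.state_dict().keys()` (where `model` is assumed to be your network instance), so it runs without PyTorch:

```python
def diff_keys(checkpoint_keys, model_keys):
    """Hypothetical helper: return (missing, unexpected) key lists, sorted.

    missing:    keys the model expects but the checkpoint lacks.
    unexpected: keys in the checkpoint that the model does not know about.
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Stand-ins for pkl_load.keys() and model.state_dict().keys()
checkpoint_keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
model_keys = ['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']

missing, unexpected = diff_keys(checkpoint_keys, model_keys)
print(missing)     # ['E.bn1.bias', 'E.bn1.weight', 'E.conv1.weight']
print(unexpected)  # ['module.E.bn1.bias', 'module.E.bn1.weight', 'module.E.conv1.weight']
```

Every unexpected key is a missing key with `module.` in front, which points to a consistent prefix rather than a difference in the network's architecture.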
You could strip the leading "module." component from each key, for example:
keys = ['module.E.conv1.weight', 'module.E.bn1.weight', 'module.E.bn1.bias']
res = []
for key in keys:
    # Split on '.', drop the first component ('module'), and rejoin the rest
    words = key.split('.')
    tempRes = words[1:]
    newWord = '.'.join(tempRes)
    res.append(newWord)
print(res)
Output:
['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
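The loop above removes the first component of every key unconditionally, so it would also strip a real module name from any key that lacks the prefix, and it only produces a list of names rather than a dictionary you can load. A slightly safer sketch, using a hypothetical helper `strip_prefix`, removes `module.` only where it is present and keeps each key paired with its weights:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix='module.'):
    """Hypothetical helper: return a new OrderedDict with `prefix` removed
    from the start of each key.

    Keys that do not start with `prefix` are kept unchanged, and the values
    (the weight tensors in a real checkpoint) are carried over as-is.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


# Plain numbers stand in for tensors so the sketch runs without PyTorch
pkl_load = OrderedDict([
    ('module.E.conv1.weight', 1.0),
    ('module.E.bn1.weight', 2.0),
    ('module.E.bn1.bias', 3.0),
])
print(list(strip_prefix(pkl_load).keys()))
# ['E.conv1.weight', 'E.bn1.weight', 'E.bn1.bias']
```

With PyTorch this would be used roughly as `model.load_state_dict(strip_prefix(torch.load(trained_model_dir)))`, where `model` is assumed to be your network instance. On Python 3.9+, `key.removeprefix(prefix)` does the same job as the `startswith` check. The `module.` prefix usually means the checkpoint was saved from a model wrapped in `torch.nn.DataParallel`, so another option is to wrap your model the same way before loading. Recent PyTorch releases (1.9 and later, as far as I know) also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which strips the prefix in place.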