Zeeshan Anjum

Reputation: 450

Load a PyTorch model saved with 0.4.1 in 0.4.0?

I trained a DenseNet-161 model using PyTorch 0.4.1 (GPU), and in the testing environment I have to load it with PyTorch 0.4.0 (CPU). I am already calling model.cpu(), but when I load the state dictionary with model.load_state_dict(checkpoint['state_dict'])

I get the following error:

RuntimeError: Error(s) in loading state_dict for DenseNet: Unexpected key(s) in state_dict: "features.norm0.num_batches_tracked", "features.denseblock1.denselayer1.norm1.num_batches_tracked", "features.denseblock1.denselayer1.norm2.num_batches_tracked", "features.denseblock1.denselayer2.norm1.num_batches_tracked",...

Upvotes: 1

Views: 791

Answers (1)

Jatentaki

Reputation: 13113

This seems to stem from a difference in how normalization layers are implemented in PyTorch 0.4.1 and 0.4.0: the former tracks an additional state variable called num_batches_tracked, which PyTorch 0.4.0 does not expect. Assuming there are only unexpected keys and no missing keys (which I can't tell for sure, since you've clipped the error message), you can delete the extraneous entries and the model will hopefully load. Try:

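model_dict = checkpoint['state_dict']
# Drop the per-layer num_batches_tracked entries that 0.4.0 does not know about.
filtered = {
    k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k
}
# Load only the keys the 0.4.0 model expects.
model.load_state_dict(filtered)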

Please note, there may have been changes to the internals of normalization other than just what you're seeing here, so even if this fix suppresses the exception, the model may still silently misbehave.
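Since the error message in the question is clipped, it may be worth checking explicitly that nothing is missing once the extra keys are removed. Below is a minimal sketch of such a check in plain Python. The helper names strip_num_batches_tracked and diff_keys are hypothetical (they are not part of PyTorch), and the small dictionaries stand in for a real checkpoint so the snippet runs without a model.

```python
def strip_num_batches_tracked(state_dict):
    """Return a copy of state_dict without the num_batches_tracked entries."""
    return {
        k: v for k, v in state_dict.items()
        if not k.endswith('num_batches_tracked')
    }


def diff_keys(expected_keys, state_dict):
    """Return (missing, unexpected) key lists, each sorted alphabetically."""
    expected = set(expected_keys)
    found = set(state_dict)
    return sorted(expected - found), sorted(found - expected)


# Hypothetical stand-in data: the keys a 0.4.0 model would expect for one
# normalization layer, and what a 0.4.1 checkpoint would contain for it.
expected_keys = [
    'features.norm0.weight',
    'features.norm0.bias',
    'features.norm0.running_mean',
    'features.norm0.running_var',
]
saved = dict.fromkeys(expected_keys, 0)
saved['features.norm0.num_batches_tracked'] = 0

filtered = strip_num_batches_tracked(saved)
missing, unexpected = diff_keys(expected_keys, filtered)
print('missing:', missing)
print('unexpected:', unexpected)
```

With a real model you would pass model.state_dict().keys() as expected_keys and the filtered checkpoint dictionary as the second argument. If both lists come back empty, the key sets match; this says nothing about whether the numerical behaviour of the layers is identical between versions.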

Upvotes: 1
