crasse

Reputation: 11

Loading a *.pth checkpoint locally with PyTorch

I'm trying to load a VGG19 checkpoint offline from a local file instead of using the regular PyTorch method (which downloads it), and I ran into problems. Basically, I'm following this tutorial: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

and instead of

cnn = models.vgg19(pretrained=True).features.to(device).eval()

which works well with the rest of the code, I want to load the same weights from a local *.pth file ('vgg19-dcbb9e9d.pth', placed in a specific folder). So I tried this:

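import torch
from torchvision import models

checkpoint = torch.load('models/vgg19-dcbb9e9d.pth')
cnn = models.vgg19()  # builds the full VGG19: features, avgpool and classifier
cnn.load_state_dict(checkpoint)
cnn.eval()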

but then I got an error:

---> 32             raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
     33 
     34         model.add_module(name, layer)

RuntimeError: Unrecognized layer: Sequential

Basically, the model wasn't loaded or read correctly: it seems the code didn't find the layers it is looking for. Is there something I'm missing?

Upvotes: 0

Views: 749

Answers (1)

Sezilber Se

Reputation: 1

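Perhaps the classifier layers are not needed. `models.vgg19()` builds the whole network, whose top-level children are `features` (a `Sequential`), `avgpool` and `classifier` (another `Sequential`). The tutorial code iterates over `cnn.children()` and only accepts individual `Conv2d`, `ReLU`, `MaxPool2d` and `BatchNorm2d` layers, so it stops at the first `Sequential`. The original `models.vgg19(pretrained=True).features` passes only the convolutional part, which is why it worked.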
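To see why the error appears, here is a small reproduction. `TinyVGG` is a hypothetical stand-in that mimics torchvision's VGG layout (a `features` `Sequential`, an `avgpool` layer and a `classifier` `Sequential`) so it runs without downloading weights, and `name_layers` is a hypothetical, simplified copy of the layer-naming loop from the linked tutorial. Passing the whole model fails on its first child, while passing `.features` works.

```python
import torch.nn as nn


class TinyVGG(nn.Module):
    """Hypothetical miniature of torchvision's VGG layout (not the real VGG19)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(8, 10))

    def forward(self, x):
        return self.classifier(self.avgpool(self.features(x)))


def name_layers(cnn):
    """Hypothetical helper: simplified version of the tutorial's loop over cnn.children()."""
    names = []
    i = 0
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = "conv_{}".format(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_{}".format(i)
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_{}".format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = "bn_{}".format(i)
        else:
            # Any container such as Sequential ends up here
            raise RuntimeError("Unrecognized layer: {}".format(layer.__class__.__name__))
        names.append(name)
    return names


model = TinyVGG()
print([type(child).__name__ for child in model.children()])
# -> ['Sequential', 'AdaptiveAvgPool2d', 'Sequential']

try:
    name_layers(model)  # the first child is the whole `features` Sequential
except RuntimeError as err:
    print(err)  # -> Unrecognized layer: Sequential

print(name_layers(model.features))  # -> ['conv_1', 'relu_1', 'pool_1']
```

The real `models.vgg19()` has the same three top-level children, which is why the tutorial code has to receive `cnn.features` rather than the full model.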

Check what the model's state_dict contains:

print("Model's state_dict:")
for param_tensor in cnn.state_dict():
    print(param_tensor, "\t", cnn.state_dict()[param_tensor].size())
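To compare the checkpoint against the model, one option is to load it with `strict=False`, which makes `load_state_dict` return the `missing_keys` and `unexpected_keys` instead of raising. This is a sketch under assumptions: `TinyNet` and `compare_checkpoint` are hypothetical names, a temporary file stands in for `models/vgg19-dcbb9e9d.pth`, and the checkpoint deliberately holds only the `features.*` weights to show what a mismatch looks like.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in with VGG-style `features` and `classifier` parts."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU())
        self.classifier = nn.Sequential(nn.Linear(4, 2))


def compare_checkpoint(model, path):
    """Load a state_dict non-strictly and return (missing_keys, unexpected_keys)."""
    state_dict = torch.load(path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    return sorted(result.missing_keys), sorted(result.unexpected_keys)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    # Keep only the `features.*` entries to simulate a checkpoint without the classifier
    full = TinyNet().state_dict()
    partial = {k: v for k, v in full.items() if k.startswith("features.")}
    torch.save(partial, path)
    missing, unexpected = compare_checkpoint(TinyNet(), path)

print("missing:", missing)        # -> missing: ['classifier.0.bias', 'classifier.0.weight']
print("unexpected:", unexpected)  # -> unexpected: []
```

For the file in the question both lists should come back empty: `load_state_dict` runs in strict mode by default and did not raise, so the keys matched and the error only appeared later in the tutorial code.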

If you only need the features, then:

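import copy

# Keep only the convolutional part; the classifier is not used for style transfer
model = copy.deepcopy(cnn.features)
model.to(device).eval()

# Freeze the weights: the network is only used as a fixed feature extractor
for param in model.parameters():
    param.requires_grad = False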
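Putting the pieces together, here is a minimal sketch of a helper that builds the architecture without downloading anything, loads a local `.pth` state_dict, and returns a frozen `features` module. `load_local_features` and `TinyNet` are hypothetical names; the demo uses `TinyNet` and a temporary file so it runs without the real checkpoint.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_local_features(build_model, path, device="cpu"):
    """Hypothetical helper: build a model, load a local state_dict, return frozen `features`."""
    model = build_model()                               # architecture only, nothing is downloaded
    state_dict = torch.load(path, map_location=device)  # read the weights from disk
    model.load_state_dict(state_dict)                   # strict: all keys must match
    features = model.features.to(device).eval()         # drop avgpool and classifier
    for param in features.parameters():
        param.requires_grad = False                     # fixed feature extractor
    return features


class TinyNet(nn.Module):
    """Hypothetical stand-in for models.vgg19 so the demo runs without the real file."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU())
        self.classifier = nn.Sequential(nn.Linear(4, 2))


reference = TinyNet()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny.pth")
    torch.save(reference.state_dict(), path)
    features = load_local_features(TinyNet, path)

print(features)
```

With torchvision, the corresponding call would be `load_local_features(models.vgg19, 'models/vgg19-dcbb9e9d.pth', device)`, and the result can be used where the tutorial assigns `cnn`.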

Upvotes: 0
