Reputation: 77
I want to retrain a custom model with my small dataset. I can load the pretrained weights (.pth file) and run the model in PyTorch. However, I need more functionality, so I refactored the code into PyTorch Lightning, and now I can't figure out how to load the pretrained weights into the PyTorch Lightning model.
Please see the details of my code below:
class BDRAR(nn.Module):
    def __init__(self):
        super(BDRAR, self).__init__()
        resnext = ResNeXt101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4
PyTorch Lightning code:
class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def forward(self, x):
        return self.model(x)
PyTorch Lightning run:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path, strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)
trainer.fit(bdrar)
Error:
keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
KeyError: 'state_dict'
I would appreciate any help.
Thank you.
Upvotes: 1
Views: 5388
Reputation: 1392
It can be that your .pth file is already a state_dict. Try loading the pretrained weights inside your Lightning class:
class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)

path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)
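If you are not sure what the .pth file actually contains, a quick way to check is to load it with plain torch.load and inspect the keys. This is only a minimal sketch using the path from the question; it assumes the file is either a raw state_dict or a Lightning-style checkpoint dict.

import torch

path = './ckpt/BDRAR/3000.pth'
ckpt = torch.load(path, map_location='cpu')

# A raw state_dict is an (Ordered)dict of parameter tensors, so its keys are
# layer names such as 'layer0.0.weight'. A full Lightning checkpoint is a dict
# that contains 'state_dict' plus extra entries (epoch, optimizer states, ...).
print(type(ckpt))
print(list(ckpt.keys())[:10])

if 'state_dict' in ckpt:
    state_dict = ckpt['state_dict']   # Lightning-style checkpoint
else:
    state_dict = ckpt                 # plain state_dict saved via torch.save(model.state_dict(), path)

If the second branch is taken, that explains the KeyError from load_from_checkpoint in the question.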
Upvotes: 2
Reputation: 11
Those pretrained weights belong to the class BDRAR(nn.Module), i.e. the model attribute of your LightningModule.
The LightningModule liteBDRAR() is acting as a wrapper around your PyTorch model (located at self.model). You need to load the weights onto the PyTorch model inside your LightningModule.
As @Jules and @Dharman mentioned, what you need is:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
Upvotes: 0
Reputation: 445
You are getting this error because you are trying to load your PyTorch model's weights into the Lightning module. When Lightning saves a checkpoint, it stores not only the model's state_dict but also a bunch of other info (see here).
What you are looking for is the following:
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
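By contrast, load_from_checkpoint (the call from the question) is meant for checkpoints produced by Lightning itself. A minimal sketch of that round trip, assuming the liteBDRAR class above and a hypothetical save path, would be:

# After loading the plain PyTorch weights as above, fine-tune with Lightning
# and save a proper Lightning checkpoint.
trainer = pl.Trainer(max_epochs=1, gpus=1)
trainer.fit(bdrar)
trainer.save_checkpoint('./ckpt/BDRAR/lightning_3000.ckpt')  # hypothetical path

# This .ckpt contains a 'state_dict' key (plus epoch, optimizer states, etc.),
# so load_from_checkpoint no longer raises KeyError.
model = liteBDRAR.load_from_checkpoint('./ckpt/BDRAR/lightning_3000.ckpt')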
Upvotes: 2