John Duke

Reputation: 77

Unable to load custom pretrained weights in PyTorch Lightning

I want to retrain a custom model on my small dataset. I can load the pretrained weights (.pth) and run the model in PyTorch. However, I needed more functionality, so I refactored the code to PyTorch Lightning, and now I can't figure out how to load the pretrained weights into the PyTorch Lightning model.

Please see the details of my code below:

import torch.nn as nn

class BDRAR(nn.Module):
    def __init__(self):
        super(BDRAR, self).__init__()
        # ResNeXt101 backbone (definition not shown)
        resnext = ResNeXt101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4
        # ... remaining layers and forward() omitted

PyTorch Lightning code:

import pytorch_lightning as pl

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

PyTorch Lightning run:

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path, strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)
trainer.fit(bdrar)

Error:

keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
KeyError: 'state_dict'

I would appreciate any help.

Thank you.

Upvotes: 1

Views: 5388

Answers (3)

joe32140

Reputation: 1392

It may be that your .pth file is already a plain state_dict rather than a Lightning checkpoint. Try loading the pretrained weights inside your Lightning class instead:

import torch
import pytorch_lightning as pl

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)

path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)
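One caveat about strict=False: load_state_dict then skips every key that does not match instead of raising, so a wrong file or a key-prefix mismatch can load nothing at all without any error. load_state_dict returns a named tuple with missing_keys and unexpected_keys; the sketch below, using a hypothetical helper check_load_result, turns a bad result into an exception. It only relies on those two attributes, so it is demonstrated with a stand-in named tuple rather than a real model.

```python
from collections import namedtuple
from typing import Iterable


def check_load_result(result, allowed_missing: Iterable[str] = ()) -> None:
    """Raise if load_state_dict(..., strict=False) silently skipped keys.

    `result` is what nn.Module.load_state_dict returns: an object with
    `missing_keys` and `unexpected_keys` lists. `allowed_missing` holds key
    prefixes you expect to be absent from the file (hypothetical example:
    a newly added head that is trained from scratch).
    """
    allowed = tuple(allowed_missing)
    # str.startswith(()) is False, so with no prefixes every missing key counts.
    missing = [k for k in result.missing_keys if not k.startswith(allowed)]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"{len(missing)} missing keys (first few: {missing[:3]}), "
            f"{len(unexpected)} unexpected keys (first few: {unexpected[:3]})"
        )


# Stand-in for the named tuple returned by load_state_dict, for illustration.
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])

check_load_result(LoadResult(missing_keys=[], unexpected_keys=[]))  # passes
check_load_result(
    LoadResult(missing_keys=["head.weight"], unexpected_keys=[]),
    allowed_missing=["head."],
)  # passes: "head." was declared as expected to be missing
```

Inside load_model this would mean keeping the return value, e.g. result = self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False), and then calling check_load_result(result).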

Upvotes: 2

Sarita Hedaya

Reputation: 11

Those pretrained weights belong to class BDRAR(nn.Module), that is, the class held in your LightningModule's model attribute.

The LightningModule liteBDRAR() acts as a wrapper around your PyTorch model (stored in self.model), so you need to load the weights into the PyTorch model inside your LightningModule. As @Jules and @Dharman mentioned, what you need is:

import torch

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
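The reason the weights have to go into bdrar.model rather than bdrar itself: because liteBDRAR stores the network as self.model, every key in bdrar.state_dict() begins with the attribute name "model.", while the keys in a .pth file saved from the bare BDRAR do not. The sketch below, with a hypothetical helper add_prefix, shows the equivalent alternative of renaming the keys so they match the LightningModule's own layout.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def add_prefix(state_dict: Dict[str, V], prefix: str = "model.") -> Dict[str, V]:
    """Return a copy of `state_dict` with `prefix` prepended to every key.

    A LightningModule that stores the network as `self.model` exposes its
    parameters under keys starting with "model.", whereas a state_dict saved
    from the bare nn.Module has the same keys without that prefix.
    """
    return {prefix + key: value for key, value in state_dict.items()}
```

With it, bdrar.load_state_dict(add_prefix(torch.load(path))) should load the same weights as bdrar.model.load_state_dict(torch.load(path)), assuming liteBDRAR has no parameters or buffers outside self.model; otherwise a strict load would report those extra entries as missing.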

Upvotes: 0

Jules

Reputation: 445

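You're getting this error because you are trying to load plain PyTorch model weights into the Lightning module as if they were a Lightning checkpoint. A checkpoint saved by Lightning contains not only the model state (under the "state_dict" key) but also other information such as the epoch, the global step and the optimizer states (see here). Your .pth file has no "state_dict" key, hence the KeyError.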

What you are looking for is the following:

import torch

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.model.load_state_dict(torch.load(path))
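If you later want the same loading code to accept either the original .pth file or a checkpoint that Lightning saved while training liteBDRAR, a small helper can pick the right weights out of whatever dict torch.load returns. extract_model_state_dict is a hypothetical name, and the sketch assumes the network lives in the model attribute, as it does in liteBDRAR.

```python
from typing import Any, Dict, Mapping


def extract_model_state_dict(
    checkpoint: Mapping[str, Any], prefix: str = "model."
) -> Dict[str, Any]:
    """Return weights suitable for `bdrar.model.load_state_dict(...)`.

    Accepts either a plain state_dict (what torch.save(model.state_dict())
    produces) or a Lightning checkpoint, which stores the LightningModule's
    weights under the "state_dict" key next to other training state.
    """
    if "state_dict" in checkpoint:
        # Lightning checkpoint: keys carry the attribute-name prefix
        # ("model." here), so keep the wrapped network's entries and strip it.
        state_dict = checkpoint["state_dict"]
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    # Plain state_dict: the parameter names are already the top-level keys.
    return dict(checkpoint)
```

Usage would then be bdrar.model.load_state_dict(extract_model_state_dict(torch.load(path))) for either kind of file.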

Upvotes: 2
