Constant Ramblings
Constant Ramblings

Reputation: 31

Missing Keys in state_dict

I am having problems loading my model on google colab. here is the code:

I have attached the code below

I have tried changing the name of the statedict and it does not help basically, I am trying to save my model for later use, but, this is becoming extremely difficult since I am not being able to properly save and load it. Please help me with the problem. After the section of the code, you will also find the error that I have attached below.

here is the code

from zipfile import ZipFile
file_name = 'data.zip'
with ZipFile(file_name, 'r') as zip:
  zip.extractall()

from zipfile import ZipFile
file_name = 'results.zip'
with ZipFile(file_name, 'r') as zip:
  zip.extractall()

!pip install tensorflow-gpu

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable


batchSize = 64 
imageSize = 64 

transform = transforms.Compose([transforms.Resize(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]) 

dataset = dset.CIFAR10(root = './data', download = True, transform = transform) 
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) 


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class G(nn.Module):

    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output


netG = G()
netG.load_state_dict(torch.load('generator.pth'))
netG.eval()
#netG.apply(weights_init)



class D(nn.Module):

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(64, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(128, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(256, 512, 4, 2, 1, bias = False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(512, 1, 4, 1, 0, bias = False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)


netD = D()
netD.load_state_dict(torch.load('discriminator.pth'))
netD.eval()
#netD.apply(weights_init)


criterion = nn.BCELoss()
checkpoint = torch.load('discriminator.pth')
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
errD = checkpoint['loss']
checkpoint1 = torch.load('genrator.pth')
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG.load_state_dict(checkpoint1['optimizer_state_dict'])
errG = checkpoint1['loss']
k = epoch
for j in range(k, 10):

    for i, data in enumerate(dataloader, 0):


        netD.zero_grad()


        real, _ = data
        input = Variable(real)
        target = Variable(torch.ones(input.size()[0]))
        output = netD(input)
        errD_real = criterion(output, target)


        noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
        fake = netG(noise)
        target = Variable(torch.zeros(input.size()[0]))
        output = netD(fake.detach())
        errD_fake = criterion(output, target)


        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()



        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0]))
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()



        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch+1, 10, i+1, len(dataloader), errD.data, errG.data))
        if i % 100 == 0:
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch+1), normalize = True)

torch.save({
            'epoch': epoch,
            'model_state_dict': netD.state_dict(),
            'optimizer_state_dict': optimizerD.state_dict(),
            'loss': errD
            }, 'discriminator.pth')
torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optimizerG.state_dict(),
            'loss': errG
            }, 'generator.pth')

here is the error

RuntimeError                              Traceback (most recent call last)
<ipython-input-23-3e55546152c7> in <module>()
     26 # Creating the generator
     27 netG = G()
---> 28 netG.load_state_dict(torch.load('generator.pth'))
     29 netG.eval()
     30 #netG.apply(weights_init)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for G:
    Missing key(s) in state_dict: "main.0.weight", "main.1.weight", "main.1.bias", "main.1.running_mean", "main.1.running_var", "main.3.weight", "main.4.weight", "main.4.bias", "main.4.running_mean", "main.4.running_var", "main.6.weight", "main.7.weight", "main.7.bias", "main.7.running_mean", "main.7.running_var", "main.9.weight", "main.10.weight", "main.10.bias", "main.10.running_mean", "main.10.running_var", "main.12.weight". 
    Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss".

Upvotes: 2

Views: 5730

Answers (1)

Shai
Shai

Reputation: 114926

You need to access the 'model_state_dict' key inside the loaded checkpoint.
Try:

netG.load_state_dict(torch.load('generator.pth')['model_state_dict'])

You'll probably need to apply the same fix to the discriminator as well.

Upvotes: 4

Related Questions