I'm training a 3D-GAN to generate MRI volumes. I defined my model as follows:
```python
import torch.nn as nn

# nz, ngf, ndf, nc and leak_value are hyperparameters defined elsewhere in the script

###### Definition of the generator ######
class Generator(nn.Module):
    def __init__(self, ngpu):
        # Generator inherits all the methods of nn.Module because it subclasses it;
        # super().__init__() runs nn.Module's constructor so the layers get registered
        super(Generator, self).__init__()
        self.ngpu = ngpu
        # we can use Sequential() since the output of one layer is the input of the next one
        self.main = nn.Sequential(
            # input is the latent vector z, shape (batch_size, nz, 1, 1, 1), going into a transposed convolution
            nn.ConvTranspose3d(nz, ngf * 8, 4, stride=2, padding=0, bias=True),  # output: (batch_size, 512, 4, 4, 4)
            nn.BatchNorm3d(ngf * 8),
            nn.ReLU(True),  # True means the operation is done in place; the default is False
            nn.ConvTranspose3d(ngf * 8, ngf * 4, 4, stride=2, padding=1, bias=True),  # output: (batch_size, 256, 8, 8, 8)
            nn.BatchNorm3d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf * 4, ngf * 2, 4, stride=2, padding=1, bias=True),  # output: (batch_size, 128, 16, 16, 16)
            nn.BatchNorm3d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf * 2, ngf, 4, stride=2, padding=1, bias=True),  # output: (batch_size, 64, 32, 32, 32)
            nn.BatchNorm3d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf, nc, 4, stride=2, padding=1, bias=True),  # output: (batch_size, 1, 64, 64, 64)
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

###### Definition of the Discriminator ######
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is a volume of shape (batch_size, 1, 64, 64, 64)
            nn.Conv3d(nc, ndf, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf),
            nn.LeakyReLU(leak_value, inplace=True),
            nn.Conv3d(ndf, ndf * 2, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(leak_value, inplace=True),
            nn.Conv3d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(leak_value, inplace=True),
            nn.Conv3d(ndf * 4, ndf * 8, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 8),
            nn.LeakyReLU(leak_value, inplace=True),
            nn.Conv3d(ndf * 8, nc, 4, stride=1, padding=0, bias=True),  # output: (batch_size, 1, 1, 1, 1)
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
```
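Tracing the layer arithmetic by hand gives every weight and output shape of both networks, which makes a size-mismatch report easier to read. This is a plain-Python sketch (no PyTorch needed); it assumes nz = 200, ngf = ndf = 64 and nc = 1, values inferred from the shapes in the error message rather than stated in the question, and uses the standard output-size formulas for dilation 1 and no output_padding. The helper names (`trace`, `conv_out`, `conv_transpose_out`) are illustrative only.

```python
# Layer arithmetic for the question's 3D-GAN (assumed: nz=200, ngf=ndf=64, nc=1,
# kernel size 4 in every layer).

KERNEL = 4


def conv_transpose_out(size, stride, padding, kernel=KERNEL):
    """Spatial output size of ConvTranspose3d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, stride, padding, kernel=KERNEL):
    """Spatial output size of Conv3d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(layers, size, transposed):
    """Return (weight_shape, output_shape) per layer; output_shape excludes the batch dim."""
    rows = []
    for c_in, c_out, stride, padding in layers:
        if transposed:
            size = conv_transpose_out(size, stride, padding)
            # ConvTranspose3d stores its weight as (in_channels, out_channels, kD, kH, kW)
            weight = (c_in, c_out, KERNEL, KERNEL, KERNEL)
        else:
            size = conv_out(size, stride, padding)
            # Conv3d stores its weight as (out_channels, in_channels, kD, kH, kW)
            weight = (c_out, c_in, KERNEL, KERNEL, KERNEL)
        rows.append((weight, (c_out, size, size, size)))
    return rows


nz, ngf, ndf, nc = 200, 64, 64, 1

# (in_channels, out_channels, stride, padding) for each convolution, in order
generator_layers = [
    (nz, ngf * 8, 2, 0),
    (ngf * 8, ngf * 4, 2, 1),
    (ngf * 4, ngf * 2, 2, 1),
    (ngf * 2, ngf, 2, 1),
    (ngf, nc, 2, 1),
]
discriminator_layers = [
    (nc, ndf, 2, 1),
    (ndf, ndf * 2, 2, 1),
    (ndf * 2, ndf * 4, 2, 1),
    (ndf * 4, ndf * 8, 2, 1),
    (ndf * 8, nc, 1, 0),
]

# The generator starts from z of shape (batch_size, nz, 1, 1, 1);
# the discriminator starts from a (batch_size, 1, 64, 64, 64) volume.
gen_rows = trace(generator_layers, 1, transposed=True)
disc_rows = trace(discriminator_layers, 64, transposed=False)

for name, rows in (("Generator", gen_rows), ("Discriminator", disc_rows)):
    print(name)
    for weight, out in rows:
        print(f"  weight {weight} -> output {out}")
```

Under these assumptions the Generator's convolution weights run from (200, 512, 4, 4, 4) to (64, 1, 4, 4, 4), while the Discriminator's run from (64, 1, 4, 4, 4) to (1, 512, 4, 4, 4), because the two layer types store in/out channels in opposite order.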
I then train the model and save it. When loading the model for evaluation and testing I get the following error:
```text
RuntimeError: Error(s) in loading state_dict for Generator:
    size mismatch for main.0.weight: copying a param with shape torch.Size([64, 1, 4, 4, 4]) from checkpoint, the shape in current model is torch.Size([200, 512, 4, 4, 4]).
    size mismatch for main.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for main.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for main.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for main.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for main.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for main.3.weight: copying a param with shape torch.Size([128, 64, 4, 4, 4]) from checkpoint, the shape in current model is torch.Size([512, 256, 4, 4, 4]).
    size mismatch for main.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for main.4.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for main.4.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for main.4.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for main.4.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for main.6.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for main.7.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for main.7.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for main.7.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for main.7.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
    size mismatch for main.9.weight: copying a param with shape torch.Size([512, 256, 4, 4, 4]) from checkpoint, the shape in current model is torch.Size([128, 64, 4, 4, 4]).
    size mismatch for main.9.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for main.10.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for main.10.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for main.10.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for main.10.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for main.12.weight: copying a param with shape torch.Size([1, 512, 4, 4, 4]) from checkpoint, the shape in current model is torch.Size([64, 1, 4, 4, 4]).
```
What am I doing wrong?
Thanks in advance!
Answer:
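The checkpoint you loaded and the model you are loading it into are not identical, so `load_state_dict` raises this error to report every parameter whose size differs. Check your saving and loading code again; the model may not have been saved properly, or the wrong file may be loaded.

In this case the checkpoint's shapes match your Discriminator, not your Generator. `nn.Conv3d` stores its weight as `(out_channels, in_channels, kD, kH, kW)`, so `main.0.weight` with shape `[64, 1, 4, 4, 4]` is the first `nn.Conv3d(nc, ndf, 4)` (with `nc = 1`, `ndf = 64`), and `main.12.weight` with shape `[1, 512, 4, 4, 4]` is the last `nn.Conv3d(ndf * 8, nc, 4)`. `nn.ConvTranspose3d` stores its weight as `(in_channels, out_channels, kD, kH, kW)`, which is why the Generator expects `[200, 512, 4, 4, 4]` for `main.0.weight`. Most likely the Discriminator's `state_dict` was saved to, or loaded from, the file meant for the Generator.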
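One way to keep the two checkpoints apart is to save both state dicts into a single file under explicit keys and restore each network from its own key. This is a minimal sketch, assuming the networks are saved with `torch.save(model.state_dict(), ...)` (the question does not show its saving code). `SmallGenerator` and `SmallDiscriminator` are hypothetical, scaled-down stand-ins that, like the real classes, keep their layers in an `nn.Sequential` named `main`, so their `state_dict` keys collide; `shape_mismatches`, `save_gan` and `load_generator` are illustrative helpers, not PyTorch APIs.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical, scaled-down stand-ins for the question's Generator and
# Discriminator. Like the real classes, both keep their layers in an
# nn.Sequential called `main`, so their state_dict keys collide.
class SmallGenerator(nn.Module):
    def __init__(self, nz=8, ngf=4, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose3d(nz, ngf, 4, stride=2, padding=0),
            nn.BatchNorm3d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose3d(ngf, nc, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


class SmallDiscriminator(nn.Module):
    def __init__(self, nc=1, ndf=4):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv3d(nc, ndf, 4, stride=2, padding=1),
            nn.BatchNorm3d(ndf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(ndf, nc, 4, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


def shape_mismatches(model, state_dict):
    """List (key, checkpoint_shape, model_shape) for keys whose shapes differ."""
    model_state = model.state_dict()
    return [
        (key, tuple(tensor.shape), tuple(model_state[key].shape))
        for key, tensor in state_dict.items()
        if key in model_state and tensor.shape != model_state[key].shape
    ]


def save_gan(path, netG, netD):
    # One file with explicit keys, so the two state dicts cannot be swapped by file name.
    torch.save({"generator": netG.state_dict(), "discriminator": netD.state_dict()}, path)


def load_generator(path):
    checkpoint = torch.load(path, map_location="cpu")
    netG = SmallGenerator()
    state_dict = checkpoint["generator"]
    problems = shape_mismatches(netG, state_dict)
    if problems:
        raise ValueError(f"checkpoint does not match SmallGenerator: {problems}")
    netG.load_state_dict(state_dict)
    netG.eval()  # use BatchNorm running statistics for evaluation/testing
    return netG


if __name__ == "__main__":
    netG, netD = SmallGenerator(), SmallDiscriminator()

    # Reproduce the question's mistake: discriminator weights checked against a generator.
    for key, ckpt_shape, model_shape in shape_mismatches(netG, netD.state_dict()):
        print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "gan_checkpoint.pth")
        save_gan(path, netG, netD)
        restored = load_generator(path)
        print("restored generator, training mode:", restored.training)
```

Because the generator is restored from `checkpoint["generator"]` by key rather than by file name, the Discriminator's weights cannot end up in it by accident, and the shape check reports the offending keys before `load_state_dict` is called.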