Dalek
Dalek

Reputation: 4318

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4000x20 and 200x441)

The architecture of my convolutional variational autoencoder (encoder and decoder) is given in the snippet below:

class ConvolutionalVAE(nn.Module):
    """Convolutional variational autoencoder: conv encoder, FC latent heads, transposed-conv decoder.

    Only ``__init__`` and ``decoder`` are shown here. ``determine_decoder_params``
    must be defined elsewhere on the class. Its two results are used as the
    per-layer kernel sizes of the transposed convolutions and as the width of
    the square feature map those convolutions start from.
    """

    def __init__(self, nchannel, base_channels, z_dim, hidden_dim, device, img_width, batch_size):
        """Build the encoder and decoder layers and move the module to a device.

        Args:
            nchannel: number of image channels (input of ``conv0``, output of
                the last transposed convolution).
            base_channels: channel multiplier for the conv / transposed-conv stacks.
            z_dim: length of a latent vector z.
            hidden_dim: width of the hidden fully connected layer feeding the z heads.
            device: device to place the module on. If ``None``, ``cuda:0`` is
                used when CUDA is available, otherwise ``cpu``.
            img_width: input image width. Images are assumed square, because the
                flattened encoder size is computed as ``out_width ** 2``.
            batch_size: stored on the instance; not used by the layers built here.

        Raises:
            ValueError: if ``determine_decoder_params`` yields fewer than two
                kernel sizes (the decoder needs a first and a last layer).
        """
        super().__init__()

        self.nchannel = nchannel
        self.base_channels = base_channels
        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.img_width = img_width
        self.batch_size = batch_size
        self.enc_kernel = 4
        self.enc_stride = 2
        self._to_linear = None
        # Respect the caller's device; only fall back to auto-detection when none is given.
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        ########################
        # ENCODER - CONVOLUTION LAYERS (4 x unpadded, stride-2 convolutions)
        self.conv0 = nn.Conv2d(nchannel, base_channels, self.enc_kernel, stride=self.enc_stride)
        self.bn2d_0 = nn.BatchNorm2d(base_channels)
        self.LeakyReLU_0 = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(base_channels, base_channels * 2, self.enc_kernel, stride=self.enc_stride)
        self.bn2d_1 = nn.BatchNorm2d(base_channels * 2)
        self.LeakyReLU_1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(base_channels * 2, base_channels * 4, self.enc_kernel, stride=self.enc_stride)
        self.bn2d_2 = nn.BatchNorm2d(base_channels * 4)
        self.LeakyReLU_2 = nn.LeakyReLU(0.2)
        self.conv3 = nn.Conv2d(base_channels * 4, base_channels * 8, self.enc_kernel, stride=self.enc_stride)
        self.bn2d_3 = nn.BatchNorm2d(base_channels * 8)
        self.LeakyReLU_3 = nn.LeakyReLU(0.2)

        # Spatial width after the four convolutions above:
        # w_out = floor((w_in - kernel) / stride) + 1 for an unpadded Conv2d.
        out_width = img_width
        for _ in range(4):
            out_width = (out_width - self.enc_kernel) // self.enc_stride + 1
        out_width = int(out_width)

        ########################
        # ENCODER - FULLY CONNECTED LAYERS -> LATENT SPACE (Z)
        self.flatten = nn.Flatten()
        self.fc0 = nn.Linear((out_width ** 2) * base_channels * 8, base_channels * 8 * 4 * 4, bias=False)
        self.bn1d = nn.BatchNorm1d(base_channels * 8 * 4 * 4)
        self.fc1 = nn.Linear(base_channels * 8 * 4 * 4, hidden_dim, bias=False)
        self.bn1d_1 = nn.BatchNorm1d(hidden_dim)
        # mean of z
        self.fc2 = nn.Linear(hidden_dim, z_dim, bias=False)
        self.bn1d_2 = nn.BatchNorm1d(z_dim)
        # variance of z
        self.fc3 = nn.Linear(hidden_dim, z_dim, bias=False)
        self.bn1d_3 = nn.BatchNorm1d(z_dim)

        ########################
        # DECODER: P(X|Z)
        kernels, in_width = self.determine_decoder_params(self.z_dim, self.img_width)
        n_layers = len(kernels)
        if n_layers < 2:
            raise ValueError(
                f"decoder needs at least 2 transposed-convolution kernels, got {n_layers}"
            )
        self.conv2d_transpose_input_width = in_width
        n_features = in_width ** 2
        # z -> in_width**2 features, later reshaped to a 1-channel in_width x in_width map.
        self.px_z_fc_0 = nn.Linear(self.z_dim, n_features)
        self.px_z_bn1d_0 = nn.BatchNorm1d(n_features)
        self.px_z_fc_1 = nn.Linear(n_features, n_features)

        self.px_z_conv_transpose2d = nn.ModuleList()
        self.px_z_bn2d = nn.ModuleList()
        self.n_conv2d_transpose = n_layers
        # First layer: 1 channel -> base_channels * (n_layers - 1).
        self.px_z_conv_transpose2d.append(
            nn.ConvTranspose2d(1, base_channels * (n_layers - 1), kernel_size=kernels[0], stride=2)
        )
        self.px_z_bn2d.append(nn.BatchNorm2d(base_channels * (n_layers - 1)))
        self.px_z_LeakyReLU = nn.ModuleList()
        self.px_z_LeakyReLU.append(nn.LeakyReLU(0.2))

        # Middle layers shrink the channel count by base_channels at each step.
        for i in range(1, n_layers - 1):
            self.px_z_conv_transpose2d.append(
                nn.ConvTranspose2d(
                    base_channels * (n_layers - i),
                    base_channels * (n_layers - i - 1),
                    kernel_size=kernels[i],
                    stride=2,
                )
            )
            self.px_z_bn2d.append(nn.BatchNorm2d(base_channels * (n_layers - i - 1)))
            self.px_z_LeakyReLU.append(nn.LeakyReLU(0.2))

        # Last layer: base_channels -> image channels (followed by a sigmoid in decoder()).
        self.px_z_conv_transpose2d.append(
            nn.ConvTranspose2d(base_channels, self.nchannel, kernel_size=kernels[-1], stride=2)
        )
        self.to(device=self.device)

    def decoder(self, z_input):
        """Decode latent vectors into reconstructed images, i.e. P(X|Z).

        Args:
            z_input: tensor of shape ``(batch, z_dim)``, one latent vector per row.

        Returns:
            Tensor of shape ``(batch, nchannel, H, W)`` with values in ``[0, 1]``
            (sigmoid output). H and W are determined by the transposed-convolution
            kernels.

        Raises:
            RuntimeError: if ``z_input`` is not 2-D with last dimension ``z_dim``.
                ``px_z_fc_0``, ``BatchNorm1d`` and the ``view`` below only work for
                that layout. Any other layout would fail deeper inside with an
                opaque "mat1 and mat2 shapes cannot be multiplied" message.
        """
        if z_input.dim() != 2 or z_input.size(1) != self.z_dim:
            raise RuntimeError(
                f"decoder expects z_input of shape (batch, {self.z_dim}), got "
                f"{tuple(z_input.shape)}; reshape it so each row is one latent vector"
            )
        h = F.relu(self.px_z_bn1d_0(self.px_z_fc_0(z_input)))
        h = self.px_z_fc_1(h)
        width = self.conv2d_transpose_input_width
        # Reinterpret the flat features as a single-channel square feature map.
        h = h.view(h.size(0), 1, width, width)
        last = self.n_conv2d_transpose - 1
        for i in range(last):
            h = self.px_z_LeakyReLU[i](self.px_z_bn2d[i](self.px_z_conv_transpose2d[i](h)))
        return torch.sigmoid(self.px_z_conv_transpose2d[last](h))

Running the following code to reconstruct the images:

# Assemble latent samples for the decoder. This fragment runs inside a method
# (it uses self); `components` is defined elsewhere.
all_z = []
for d in range(self.z_dim):
   # One column per component k for latent coordinate d (printed as torch.Size([50, 20])).
   temp_z = torch.cat( [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)], dim=1)
   print(f'size of each z component dimension: {temp_z.size()}')
   # transpose -> (K, 50), times `components`; unsqueeze(1) adds the axis that
   # the concatenation below fills with the z_dim latent coordinates.
   all_z.append(torch.mm(temp_z.transpose(1, 0), components).unsqueeze(1))
# Dim 1 now has size z_dim (one slice per d). The reported shape is torch.Size([20, 200, 20]).
out       = torch.cat( all_z,1)
# NOTE(review): decoder() feeds its input straight into nn.Linear(z_dim, ...) and
# BatchNorm1d, so it needs a 2-D (N, z_dim) tensor. Passing this 3-D tensor makes
# F.linear flatten it to (4000, 20), hence "4000x20 and 200x441".
x_samples       = self.decoder(out) 

I got this error message:

size of z dimension: 200
size of each z component dimension: torch.Size([50, 20])
size of all z component dimension: torch.Size([20, 200, 20])
x_samples = self.decoder(out)
File "VAE.py", line 241, in decoder
h  = F.relu(self.px_z_bn1d_0(self.px_z_fc_0(z_input)))
File "/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
return F.linear(input, self.weight, self.bias)
File "/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4000x20 and 200x441)

Update: I changed my code slightly to this:

# Same assembly as before, followed by a pad/rearrange attempt to make the shape fit.
all_z = []
for d in range(self.z_dim):
    temp_z = torch.cat( [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)], dim=1)
    all_z.append(torch.mm(temp_z.transpose(1, 0), components).unsqueeze(1))
# Dim 1 holds the z_dim latent coordinates (reported earlier as [20, 200, 20]).
out       = torch.cat( all_z,1)
print(f'size of all z component dimension: {out.size()}')
# F.pad reads `pad` from the last dim backwards: last dim +1 on the left,
# dim 1 unchanged, dim 0 +1 on the right -> presumably [21, 200, 21].
out = F.pad(input=out, pad=(1, 0, 0,0, 0, 1), mode='constant', value=0)
print(f'new size of all z component dimension after padding: {out.size()}')
# NOTE(review): this moves the z axis (d1) to the FRONT, giving (200, 441):
# 200 rows of 441 features. px_z_fc_0 expects z_dim = 200 features per row, hence
# "200x441 and 200x441". Keeping d1 last on the unpadded tensor, e.g.
# 'd0 d1 d2 -> (d0 d2) d1', would give (400, 200) without any padding.
# Confirm that this row ordering of samples is the one intended.
out = rearrange(out, 'd0 d1 d2 -> d1 (d0 d2)')
x_samples       = self.decoder(out)

Now the new error is:

x_samples       = self.decoder(out)
File "VAE.py", line 243, in decoder
h  = F.relu(self.px_z_bn1d_0(self.px_z_fc_0(z_input)))
File "/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/anaconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
return F.linear(input, self.weight, self.bias)
File "/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (200x441 and 200x441)

Any suggestions on how to fix this error?

Upvotes: 0

Views: 849

Answers (1)

QuantumMecha
QuantumMecha

Reputation: 1541

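Matrix multiplication requires the two inner dimensions to be the same: an (a x b) matrix can only be multiplied by a (b x c) matrix. You are getting the error `RuntimeError: mat1 and mat2 shapes cannot be multiplied (200x441 and 200x441)` because your inner dimensions (441 and 200) don't line up. In this traceback, mat1 is the input you pass to the decoder. mat2 is the transposed weight of `px_z_fc_0 = nn.Linear(z_dim, 441)`, i.e. a `200 x 441` matrix, since your `z_dim` is 200.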

For example:

# Illustrative notation only: shape(r, c) stands for any r x c matrix and `*`
# for matrix multiplication (torch.mm / `@`).
shape(200, 441) * shape(441, 200) # works -> result is shape(200, 200)
shape(441, 200) * shape(200, 441) # works -> result is shape(441, 441)
shape(200, 441) * shape(200, 441) # doesn't work, this is why you are getting your error

# in general
shape(x, y) * shape(y, z) # works -> result is shape(x, z)

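To make the inner dimensions match, take the transpose of one operand or the other (examples below). In your case only the input can change, because mat2 comes from the layer's own weights. Pass the decoder a tensor of shape `(N, z_dim)` = `(N, 200)`, with one latent vector per row: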

# Illustrative notation only (see above). With nn.Linear, F.linear computes
# input @ weight.T, so "mat2" is the layer's weight and it is the input you reshape.
shape(200, 441) * shape(200, 441).T # works -> shape(200, 200)
# or
shape(200, 441).T * shape(200, 441) # works -> shape(441, 441)

# since the transpose works by swapping the dimensions:
shape(200, 441).T = shape(441, 200)

Upvotes: 1

Related Questions