Composer

Reputation: 255

Upsampling an autoencoder in pytorch

I have defined my autoencoder in PyTorch as follows (it gives me an 8-dimensional bottleneck at the output of the encoder, which works fine: torch.Size([1, 8, 1, 1])):

self.encoder = nn.Sequential(
    nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(64, 8, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(7, stride=1)
)

self.decoder = nn.Sequential(
    nn.ConvTranspose2d(8, 64, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Conv2d(64, 32, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(32, input_shape[0], kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Sigmoid()
)

What I cannot do is train the autoencoder with this forward pass:

def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)
    return x

The decoder fails with an error saying it cannot upsample the tensor:

Calculated padded input size per channel: (3 x 3). Kernel size: (4 x 4). Kernel size can't be greater than actual input size

Upvotes: 0

Views: 5740

Answers (2)

Composer

Reputation: 255

I have managed to implement an autoencoder that provides unsupervised clustering (in my case, 8 classes).

This is not an expert solution. I owe thanks to @Szymon Maszke for the suggestions.

self.encoder = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(64, 2, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(6, stride=1)
)

self.decoder = nn.Sequential(
    nn.ConvTranspose2d(2, 64, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.ConvTranspose2d(64, 32, kernel_size=8, stride=4),
    nn.ReLU(),
    nn.ConvTranspose2d(32, 1, kernel_size=8, stride=4)
)
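
A quick sanity check of the shapes (this assumes an 84x84 single-channel input, which is what produces the bottleneck sizes above; Autoencoder is just a hypothetical module wrapping the two Sequentials):

import torch

model = Autoencoder()              # hypothetical wrapper around the encoder/decoder above
x = torch.randn(1, 1, 84, 84)      # assumed input size
z = model.encoder(x)
print(z.shape)                     # torch.Size([1, 2, 2, 2])
print(model.decoder(z).shape)      # torch.Size([1, 1, 84, 84])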

Upvotes: 2

Szymon Maszke

Reputation: 24701

You are not upsampling enough via ConvTranspose2d; the spatial shape of your encoder's output is only 1 pixel (width x height). See this example:

import torch

layer = torch.nn.ConvTranspose2d(8, 64, kernel_size=3, stride=1)
print(layer(torch.randn(64, 8, 1, 1)).shape)  # torch.Size([64, 64, 3, 3])

This prints exactly the (3 x 3) spatial shape from your error message after upsampling.
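
With no padding, ConvTranspose2d produces out = (in - 1) * stride + kernel_size per spatial dimension, so a 1x1 input with kernel_size=3 and stride=1 becomes 3x3, and the Conv2d with kernel_size=4 that follows in your decoder cannot run on it.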

You can:

  • Make the kernel smaller - instead of 4 in the first Conv2d of the decoder, use 3, 2 or even 1
  • Upsample more, for example torch.nn.ConvTranspose2d(8, 64, kernel_size=7, stride=2) would give you 7x7 (see the sketch after this list)
  • What I would do personally: downsample less in the encoder, so the output shape after it is at least 4x4 or maybe 5x5. If you squash your image that much, there is no way to encode enough information into one pixel, and even if the code runs the network won't learn any useful representation.
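
As a concrete example of the second option, here is a minimal decoder sketch that mirrors your encoder layer by layer. It assumes the 84x84 single-channel input that would give you the 1x1 bottleneck; the kernel sizes are just one combination that makes the shapes line up:

import torch
import torch.nn as nn

decoder = nn.Sequential(
    nn.ConvTranspose2d(8, 64, kernel_size=7, stride=1),   # 1x1   -> 7x7  (undoes MaxPool2d(7))
    nn.ReLU(),
    nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1),  # 7x7   -> 9x9
    nn.ReLU(),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2),  # 9x9   -> 20x20
    nn.ReLU(),
    nn.ConvTranspose2d(32, 1, kernel_size=8, stride=4),   # 20x20 -> 84x84
    nn.Sigmoid()
)

print(decoder(torch.randn(1, 8, 1, 1)).shape)  # torch.Size([1, 1, 84, 84])

Each transposed convolution reverses one encoder stage, so the spatial sizes retrace 1 -> 7 -> 9 -> 20 -> 84 in reverse order of the encoder.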

Upvotes: 4
