DLH

Reputation: 199

Why are the parameters of this PyTorch AutoEncoder hardcoded this way?

Hi, I am trying to understand how the following PyTorch AutoEncoder code works. The code below uses the MNIST dataset, whose images are 28x28. My question is: how were the nn.Linear(128, 3) parameters chosen?

I have a dataset of 512x512 images and I would like to modify the code for this AutoEncoder to support it.

```python
import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        # (x must already be flattened to [batch, 28 * 28] here)
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)  # flatten each image to a 1D vector
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
```

Upvotes: 0

Views: 406

Answers (1)

ihdv

Reputation: 2287

I am assuming the input image data have this shape: x.shape == [bs, 1, h, w], where bs is the batch size. In training_step, x.view(x.size(0), -1) first reshapes x to [bs, h*w], i.e. [bs, 28*28]. This means all pixels of each image are flattened into a 1D vector.

Then in the encoder:

  • nn.Linear(28*28, 128) takes flattened input of size [bs, 28*28] and outputs intermediate result of size [bs, 128]
  • nn.Linear(128, 3): [bs, 128] -> [bs, 3]

Then in the decoder:

  • nn.Linear(3, 128): [bs, 3] -> [bs, 128]
  • nn.Linear(128, 28*28): [bs, 128] -> [bs, 28*28]
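To check the shapes listed above, you could push a random batch through the same layers. This is a minimal sketch that assumes the [bs, 1, 28, 28] input layout; the batch size of 4 and the random data are arbitrary stand-ins.

```python
import torch
from torch import nn

# Same layers as in the question (MNIST: 28x28 single-channel images)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

bs = 4                             # arbitrary batch size for illustration
x = torch.rand(bs, 1, 28, 28)      # assumed input layout [bs, 1, h, w]
x_flat = x.view(x.size(0), -1)     # [4, 784]
hidden = encoder[:2](x_flat)       # after Linear(784, 128) + ReLU: [4, 128]
z = encoder(x_flat)                # bottleneck: [4, 3]
x_hat = decoder(z)                 # reconstruction: [4, 784]
print(x_flat.shape, hidden.shape, z.shape, x_hat.shape)
```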

The final output [bs, 28*28] is then compared against the flattened input using the mean squared error loss (F.mse_loss).
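For reference, F.mse_loss with its default reduction="mean" averages the squared differences over every element of the batch. The sketch below checks that on random stand-in tensors, not real MNIST data.

```python
import torch
import torch.nn.functional as F

x = torch.rand(4, 784)        # stand-in for a flattened input batch
x_hat = torch.rand(4, 784)    # stand-in for the decoder output
loss = F.mse_loss(x_hat, x)   # default reduction="mean"
manual = ((x_hat - x) ** 2).mean()
print(torch.allclose(loss, manual))  # True
```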

If you want to use the exact same architecture for your 512x512 images, simply change every occurrence of 28*28 in the code to 512*512. However, this is not a practical choice, for these reasons:

  • For MNIST images, nn.Linear(28*28, 128) contains 28x28x128+128=100480 parameters, while for your images nn.Linear(512*512, 128) contains 512x512x128+128=33554560 parameters, roughly 334 times as many. A layer this large is expensive to store and train, and it is likely to overfit
  • The bottleneck [bs, 3] uses only 3 floats to encode a 512x512 image (262144 pixels). I don't think you can recover much of the image from such an extreme compression
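The "change 28*28 everywhere" approach can be sketched with the image side as a parameter, which also reproduces the parameter counts above. make_autoencoder and linear_param_count are hypothetical helpers for illustration, not part of PyTorch or Lightning. The 512x512 count is computed from the formula rather than by building the layer, since that layer alone would hold about 33.5 million floats.

```python
import torch
from torch import nn


def make_autoencoder(side: int, hidden: int = 128, latent: int = 3):
    # Hypothetical helper: same MLP encoder/decoder as LitAutoEncoder,
    # for square single-channel images of size side x side.
    n_pixels = side * side
    encoder = nn.Sequential(nn.Linear(n_pixels, hidden), nn.ReLU(), nn.Linear(hidden, latent))
    decoder = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(), nn.Linear(hidden, n_pixels))
    return encoder, decoder


def linear_param_count(in_features: int, out_features: int) -> int:
    # Weight matrix (out x in) plus one bias per output unit
    return in_features * out_features + out_features


enc28, dec28 = make_autoencoder(28)
first28 = enc28[0]  # nn.Linear(784, 128)
print(sum(p.numel() for p in first28.parameters()))  # 100480
print(linear_param_count(28 * 28, 128))              # 100480
print(linear_param_count(512 * 512, 128))            # 33554560
```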

I'd suggest looking up convolutional architectures for your purpose.
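As a starting point, here is a minimal convolutional autoencoder sketch for single-channel 512x512 images. The number of layers, the channel widths, and the final Sigmoid (which assumes pixel values scaled to [0, 1]) are illustrative assumptions, not tuned values. Each stride-2 Conv2d halves the height and width, and each ConvTranspose2d with output_padding=1 doubles them back, so the output has the same shape as the input.

```python
import torch
from torch import nn


class ConvAutoEncoder(nn.Module):
    """Illustrative convolutional autoencoder for [bs, 1, 512, 512] inputs."""

    def __init__(self):
        super().__init__()
        # Spatial size: 512 -> 256 -> 128 -> 64 -> 32
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 16, kernel_size=3, stride=2, padding=1),  # bottleneck: 16 x 32 x 32
        )
        # Spatial size: 32 -> 64 -> 128 -> 256 -> 512
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 64, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # assumes pixel values scaled to [0, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = ConvAutoEncoder()
x = torch.rand(2, 1, 512, 512)  # random stand-in batch
z = model.encoder(x)
x_hat = model(x)
print(z.shape)      # torch.Size([2, 16, 32, 32])
print(x_hat.shape)  # torch.Size([2, 1, 512, 512])
n_params = sum(p.numel() for p in model.parameters())
```

Here the bottleneck holds 16x32x32 = 16384 values, 1/16 of the 262144 input pixels, and the whole model has well under 100k parameters, because convolution weights are shared across image positions instead of connecting every pixel to every hidden unit. To use it in Lightning, this encoder/decoder could replace the nn.Linear stacks in LitAutoEncoder, with the x.view(...) flattening removed from training_step and the loss computed directly on the [bs, 1, 512, 512] tensors.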

Upvotes: 1
