RR_28023

Reputation: 168

Pretrained lightning-bolts VAE not doing proper inference on training dataset

I'm using the CIFAR-10 pretrained VAE from lightning-bolts. It should be able to reconstruct images with the quality shown in this picture from the docs (left: real images; right: generated images):

[Image: reconstructions shown in the lightning-bolts docs]

However, when I write a simple script that loads the model and the weights and tests it on the training set, I get much worse reconstructions (top row: real images; bottom row: generated images):

[Image: my reconstructions on the training set]

Here is a link to a self-contained Colab notebook that reproduces the steps I followed to produce these pictures.

Am I doing something wrong in my inference process? Or could it be that the weights are not as "good" as the docs claim?

Thanks!

Upvotes: 1

Views: 391

Answers (1)

cnash

Reputation: 633

First, the image from the docs you show is for the AE, not the VAE. The results for the VAE look much worse:
[Images: VAE reconstructions from the docs]
https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/vae_output.png

Second, the docs state: "Both input and generated images are normalized versions as the training was done with such images." So when you load the data, you should pass normalize=True, and when you plot the images, you need to 'unnormalize' them again:

from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms

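torch.manual_seed(17)
np.random.seed(17)

vae = VAE(32, lr=0.00001)
# from_pretrained() returns a new model loaded with the pretrained weights
vae = vae.from_pretrained("cifar10-resnet18")

# normalize=True applies the same normalization the model was trained with
dm = CIFAR10DataModule(".", normalize=True)
dm.prepare_data()
dm.setup("fit")
dataloader = dm.train_dataloader()

# default_transforms() is Compose([ToTensor(), Normalize(mean, std)])
print(dm.default_transforms())
normalize = dm.default_transforms().transforms[1]
mean = torch.tensor(normalize.mean)
std = torch.tensor(normalize.std)
# Inverse of Normalize: (x_norm - (-mean / std)) / (1 / std) == x_norm * std + mean
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

X, _ = next(iter(dataloader))
vae.eval()
with torch.no_grad():  # inference only, so no autograd graph is needed
    X_hat = vae(X)

def to_image(x):
    # Normalized CHW tensor -> HWC array clipped to [0, 1], as expected by imshow
    return unnormalize(x).permute(1, 2, 0).clamp(0, 1).numpy()

fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(10):
    ax_real = axes[0][i]  # top row: real images
    ax_real.imshow(to_image(X[i]))
    ax_real.get_xaxis().set_visible(False)
    ax_real.get_yaxis().set_visible(False)

    ax_gen = axes[1][i]  # bottom row: reconstructions
    ax_gen.imshow(to_image(X_hat[i]))
    ax_gen.get_xaxis().set_visible(False)
    ax_gen.get_yaxis().set_visible(False)

plt.show()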

Which gives something like this: [Image: pytorch-lightning VAE reconstruction, unnormalized for plotting]

Without normalization, it looks like this: [Image: reconstruction without normalization]
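The unnormalize trick in the script works because applying Normalize a second time with mean' = -mean/std and std' = 1/std exactly undoes the first pass, giving back x_norm * std + mean. Below is a minimal NumPy sketch of that per-channel arithmetic, mirroring what transforms.Normalize computes. It also shows the scale mismatch: raw pixels live in [0, 1], while the normalized inputs the model was trained on span roughly [-2, 2]. The mean/std values are made-up placeholders (in the real script, read them from dm.default_transforms().transforms[1]), and normalize, unnormalize and to_hwc_image are hypothetical helpers for illustration only.

```python
import numpy as np

# Placeholder per-channel statistics (assumed values, NOT taken from pl_bolts);
# in the real script read them from dm.default_transforms().transforms[1].
mean = np.array([0.49, 0.48, 0.45])
std = np.array([0.25, 0.24, 0.26])


def normalize(x, mean, std):
    """Per-channel (x - mean) / std on a CHW array, like transforms.Normalize."""
    return (x - mean[:, None, None]) / std[:, None, None]


def unnormalize(x_norm, mean, std):
    """Invert normalize() by normalizing again with mean' = -mean/std, std' = 1/std."""
    return normalize(x_norm, -mean / std, 1.0 / std)


def to_hwc_image(x_norm, mean, std):
    """Undo normalization, move channels last and clip to [0, 1] for imshow."""
    img = unnormalize(x_norm, mean, std)
    return np.clip(np.transpose(img, (1, 2, 0)), 0.0, 1.0)


rng = np.random.default_rng(17)
x = rng.random((3, 32, 32))       # fake CHW image with pixels in [0, 1)
x_norm = normalize(x, mean, std)  # what the model sees when normalize=True

print(x.min(), x.max())            # raw pixel range
print(x_norm.min(), x_norm.max())  # normalized range, roughly [-2, 2]
print(np.allclose(unnormalize(x_norm, mean, std), x))  # True
print(to_hwc_image(x_norm, mean, std).shape)           # (32, 32, 3)
```

Feeding raw [0, 1] images to a model trained on the normalized range shifts and rescales every input, which is consistent with the poor reconstructions in the question.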

Upvotes: 3
