Reputation: 168
I'm using the CIFAR-10 pre-trained VAE from lightning-bolts. It should be able to reconstruct images with the quality shown in this picture taken from the docs (LHS are the real images, RHS are the generated ones):
However, when I write a simple script that loads the model and its weights and tests it on the training set, I get much worse reconstructions (top row are the real images, bottom row are the generated ones):
Here is a link to a self-contained colab notebook that reproduces the steps I've followed to produce the pictures.
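In short, the notebook boils down to roughly the following (a condensed sketch of the steps, not the exact cells):

from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
import matplotlib.pyplot as plt
import numpy as np

# Pre-trained CIFAR-10 VAE
vae = VAE(32)
vae = vae.from_pretrained("cifar10-resnet18")
vae.eval()

# CIFAR-10 training data with the default transforms (no normalization)
dm = CIFAR10DataModule(".")
dm.prepare_data()
dm.setup("fit")
X, _ = next(iter(dm.train_dataloader()))

# Reconstruct the first 10 images and plot them: top row real, bottom row generated
X_hat = vae(X)
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(10):
    axes[0][i].imshow(np.transpose(X[i].numpy(), (1, 2, 0)))
    axes[1][i].imshow(np.transpose(X_hat[i].detach().numpy(), (1, 2, 0)))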
Am I doing something wrong in my inference process? Could it be that the weights are not as "good" as the docs claim?
Thanks!
Upvotes: 1
Views: 391
Reputation: 633
First, the image from the docs you show is for the AE, not the VAE. The results for the VAE look much worse:
https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/vae-cifar10/vae_output.png
Second, the docs state: "Both input and generated images are normalized versions as the training was done with such images." So when you load the data, you should specify normalize=True. When you plot, you will need to 'unnormalize' the data again; since Normalize(mean, std) computes (x - mean) / std, its inverse is Normalize(-mean / std, 1 / std):
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms

torch.manual_seed(17)
np.random.seed(17)

# Load the pre-trained CIFAR-10 VAE
vae = VAE(32, lr=0.00001)
vae = vae.from_pretrained("cifar10-resnet18")

# Load CIFAR-10 with the same normalization that was used during training
dm = CIFAR10DataModule(".", normalize=True)
dm.prepare_data()
dm.setup("fit")
dataloader = dm.train_dataloader()
print(dm.default_transforms())

# Invert the normalization: Normalize(-mean/std, 1/std) undoes Normalize(mean, std)
mean = torch.tensor(dm.default_transforms().transforms[1].mean)
std = torch.tensor(dm.default_transforms().transforms[1].std)
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

X, _ = next(iter(dataloader))
vae.eval()
X_hat = vae(X)

# Top row: real images, bottom row: reconstructions (both unnormalized for plotting)
fig, axes = plt.subplots(2, 10, figsize=(10, 2))
for i in range(10):
    ax_real = axes[0][i]
    ax_real.imshow(np.transpose(unnormalize(X[i]).numpy(), (1, 2, 0)))
    ax_real.get_xaxis().set_visible(False)
    ax_real.get_yaxis().set_visible(False)

    ax_gen = axes[1][i]
    ax_gen.imshow(np.transpose(unnormalize(X_hat[i]).detach().numpy(), (1, 2, 0)))
    ax_gen.get_xaxis().set_visible(False)
    ax_gen.get_yaxis().set_visible(False)
Which gives something like this:
Without normalization, it looks like this:
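As a side note: the unnormalized reconstructions can fall slightly outside [0, 1] (imshow will just clip them and warn), and inference doesn't need gradients, so you can optionally run the forward pass under torch.no_grad() and clamp before plotting, e.g.:

# Optional: gradient-free inference and clamping to the valid image range
with torch.no_grad():
    X_hat = vae(X)
img = unnormalize(X_hat[0]).clamp(0, 1)
plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))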
Upvotes: 3