Hulio Almedo

Reputation: 67

Backtransforming a PyTorch Tensor

I have trained a WGAN on the CelebA dataset in PyTorch following this YouTube video. Since I run this on Google Cloud Platform, where TensorBoard is not available, I save a figure of images generated by the GAN every epoch to see how training is going.

The saved PDF files look something like this: generated images. Unfortunately, they are not really readable, and I suspect this has to do with the preprocessing I do:

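from torchvision import transforms

trafo = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),                          # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 -> values in [-1, 1]
])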

Is there a way to undo this transformation when I save the images?

Currently, I save the image every epoch as follows:

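import matplotlib.pyplot as plt
import torchvision
from datetime import datetime

visualization = torchvision.utils.make_grid(tensor=gen(fixed_noise), nrow=8, normalize=False)
# make_grid returns a (C, H, W) tensor; imshow expects (H, W, C)
plt.imshow(visualization.permute(1, 2, 0).detach().cpu().numpy())
plt.savefig("generated_WGAN_" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".pdf")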

Also, I should probably mention that in the Jupyter notebook, I get the following warning:

"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers)."

Upvotes: 1

Views: 1274

Answers (2)

Shai

Reputation: 114786

It seems your output pixel values are in the range [-1, 1] (please verify this).
matplotlib's imshow only accepts floats in [0, 1] for RGB data, so when you plot and save the images, every negative value is clipped to 0, as the warning you got indicates.
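One way to do that check, sketched with a hypothetical report_range helper. The tanh stand-in for the generator output is an assumption (WGAN/DCGAN-style generators commonly end in a tanh layer); pass gen(fixed_noise) instead in your own code:

```python
import torch


def report_range(t: torch.Tensor) -> tuple[float, float, float]:
    """Hypothetical helper: (min, max, fraction of values outside [0, 1])."""
    t = t.detach()
    # Values outside [0, 1] are exactly the ones imshow clips for float RGB data
    outside = ((t < 0) | (t > 1)).float().mean().item()
    return t.min().item(), t.max().item(), outside


# Hypothetical stand-in for gen(fixed_noise): 32 RGB 64x64 images squashed by tanh
fake = torch.tanh(torch.randn(32, 3, 64, 64))
lo, hi, frac = report_range(fake)
print(f"min={lo:.3f} max={hi:.3f} clipped by imshow: {frac:.1%}")
```

If the minimum is close to -1 and a large fraction of values gets clipped, the [-1, 1] hypothesis holds.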

Try:

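# gen outputs values in [-1, 1]: clamp, then x * 0.5 + 0.5 maps them to [0, 1]
visualization = torchvision.utils.make_grid(
                           tensor=torch.clamp(gen(fixed_noise), -1, 1) * 0.5 + 0.5,
                           nrow=8,
                           normalize=False)
# make_grid returns (C, H, W); imshow expects (H, W, C)
plt.imshow(visualization.permute(1, 2, 0).detach().cpu().numpy())
plt.savefig("generated_WGAN_" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".pdf")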

Upvotes: 0

Ivan

Reputation: 40628

The torchvision.transforms.Normalize transform is usually used to standardize data (make mean(data) = 0 and std(data) = 1), while the normalize option of torchvision.utils.make_grid rescales the data to [0, 1] given a range. So there is no need to implement a function to fix this. From the make_grid documentation, for normalize:

"If True, shift the image to the range (0, 1), by the min and max values specified by range. Default: False."

Here you want to rescale to [0, 1]. Given a tensor x:

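torchvision.utils.make_grid(x, nrow=8, normalize=True, value_range=(x.min(), x.max()))

(Older torchvision releases call this argument range instead of value_range. If it is omitted, make_grid uses the tensor's own min and max, so passing them explicitly is optional.)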

Here are some usage examples from the PyTorch documentation.


Back to your original question: torchvision.transforms.Normalize(mean=0.5, std=0.5) does not transform your data so that it has mean 0.5 and std 0.5, nor does it standardize it to mean 0 and std 1. It computes (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1]. To actually standardize, you have to measure the mean and std of your dataset and pass those values in.

torchvision.transforms.Normalize simply performs a shift-and-scale operation, output = (input - mean) / std. To undo it, apply the inverse with the same values, input = output * std + mean:

>>> import torch
>>> import torchvision.transforms as T
>>> mean, std = 0.5, 0.5

>>> x = torch.rand(64, 3, 100, 100)*torch.rand(64, 1, 1, 1)
>>> x.mean(), x.std()
(tensor(0.2536), tensor(0.2175))

>>> t = T.Normalize(mean, std)
>>> t_inv = lambda x: x*std + mean

>>> x_after = t(x)
>>> x_after.mean(), x_after.std()
(tensor(-0.4928), tensor(0.4350))

>>> x_before = t_inv(x_after)
>>> x_before.mean(), x_before.std()
(tensor(0.2536), tensor(0.2175))

Upvotes: 1
