Reputation: 67
I have trained a WGAN on the CelebA dataset in PyTorch following this YouTube video. Since I do this on Google Cloud Platform, where TensorBoard is not available, I save one figure of images generated by the GAN every epoch to see how the GAN is actually doing.
Now, the saved PDF files look something like this: generated images. Unfortunately, this is not really readable, and I suspect this has to do with the preprocessing I do:
trafo = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
Is there any way to kind of undo this transformation when I save the image?
Currently, I save the image every epoch as follows:
visualization = torchvision.utils.make_grid(
tensor = gen(fixed_noise),
nrow = 8,
normalize = False)
plt.savefig("generated_WGAN_" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".pdf")
Also, I should probably mention that in the Jupyter notebook, I get the following warning:
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers)."
Upvotes: 1
Views: 1274
Reputation: 114786
It seems like your output pixel values are in the range [-1, 1] (please verify this).
Therefore, when you save the images, the negative values are clipped, as the warning you got suggests.
Try:
visualization = torchvision.utils.make_grid(
tensor = torch.clamp(gen(fixed_noise), -1, 1) * 0.5 + 0.5, # from [-1, 1] -> [0, 1]
nrow = 8,
normalize = False)
plt.savefig("generated_WGAN_" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".pdf")
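The rescaling step can be checked on its own before wiring it into the training loop. The following is a minimal sketch that assumes the generator output lies roughly in [-1, 1]; `to_unit_range` and `fake_batch` are hypothetical names, and the random tensor only stands in for `gen(fixed_noise)`.

```python
import torch


def to_unit_range(images: torch.Tensor) -> torch.Tensor:
    """Clamp to [-1, 1], then map linearly onto [0, 1]."""
    return torch.clamp(images, -1.0, 1.0) * 0.5 + 0.5


# Stand-in for gen(fixed_noise): values that slightly overshoot [-1, 1].
fake_batch = torch.randn(16, 3, 64, 64).tanh() * 1.2

rescaled = to_unit_range(fake_batch)
print(rescaled.min().item(), rescaled.max().item())
```

After this step the values are valid float RGB data for imshow, so the clipping warning should no longer appear.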
Upvotes: 0
Reputation: 40628
torchvision.transforms.Normalize is usually used to standardize data (make mean(data)=0 and std(data)=1), while the normalize option of torchvision.utils.make_grid rescales the data to [0, 1] given a range. So there is no need to implement a function to fix this.
The make_grid documentation describes the normalize option as follows: "If True, shift the image to the range (0, 1), by the min and max values specified by range. Default: False."
Here you are looking to normalize between 0 and 1. Given a tensor x:
# Newer torchvision releases name this argument value_range instead of range.
torchvision.utils.make_grid(x, nrow=8, normalize=True, range=(x.min(), x.max()))
The PyTorch documentation provides some usage examples.
Back to your original question: torchvision.transforms.Normalize(mean=0.5, std=0.5) doesn't transform your data so that it has mean=0.5 and std=0.5, and it won't standardize it to mean=0, std=1 either. For that, you have to measure the mean and std of your dataset and pass those values.
torchvision.transforms.Normalize simply performs a shift-scale operation. To undo it, just unscale and unshift with the same values:
>>> import torch
>>> import torchvision.transforms as T
>>> mean, std = 0.5, 0.5
>>> x = torch.rand(64, 3, 100, 100)*torch.rand(64, 1, 1, 1)
>>> x.mean(), x.std()
(tensor(0.2536), tensor(0.2175))
>>> t = T.Normalize(mean, std)
>>> t_inv = lambda x: x*std + mean
>>> x_after = t(x)
>>> x_after.mean(), x_after.std()
(tensor(-0.4928), tensor(0.4350))
>>> x_before = t_inv(x_after)
>>> x_before.mean(), x_before.std()
(tensor(0.2536), tensor(0.2175))
Upvotes: 1