Lance4129

Reputation: 105

Pytorch transform.ToTensor() changes image

I want to convert images to tensors using torchvision.transforms.ToTensor(). After processing, I printed the image, but the image was not right. Here is my code:

import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([
    transforms.ToTensor()])

demo = Image.open(img)
demo_img = trans(demo)
demo_array = demo_img.numpy()*255
print(Image.fromarray(demo_array.astype(np.uint8)))

The original image is:

[original image]

After processing, it looks like:

[image after processing]

Upvotes: 10

Views: 49927

Answers (1)

David

Reputation: 8318

It seems that the problem is with the channel axis.

If you look at the torchvision.transforms docs, especially at ToTensor():

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

So once you apply the transformation and convert back to a numpy.array, your shape is (C, H, W), and you need to move the channel axis back. You can do the following:

demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)

This rearranges the array to shape (H, W, C), so when you convert back to PIL and display it, you get the original image.
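The axis move itself is plain NumPy and easy to check in isolation, for example on a dummy (C, H, W) array:

```python
import numpy as np

chw = np.zeros((3, 4, 6))        # (C, H, W), as returned by ToTensor().numpy()
hwc = np.moveaxis(chw, 0, -1)    # move axis 0 (channels) to the last position
print(hwc.shape)                 # (4, 6, 3), i.e. (H, W, C)
```

np.moveaxis returns a view with the requested axis order, leaving the data itself untouched.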

So in total:

import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])

demo = Image.open(img) 
demo_img = trans(demo)
demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)
print(Image.fromarray(demo_array.astype(np.uint8)))

Upvotes: 20
