binru

Reputation: 43

PyTorch: 'ToTensor()' turns a color image into 9 grayscale pictures

I have found that when I apply 'ToTensor' to an image, the single image is displayed as 9 images. I checked the official documentation but couldn't find the reason. Why does one picture become 9 pictures? The question is illustrated in the figures below.

import matplotlib.pyplot as plt

a = plt.imread('test.jpg')  # NumPy array in H x W x C order, here (375, 500, 3)
plt.imshow(a)
plt.show()

[Figure: the original color image displayed by plt.imshow(a)]

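from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])
b = transform(a)          # tensor of shape (3, 375, 500)
b = b.view(375, 500, 3)   # reinterpret the same data with shape (375, 500, 3)
plt.imshow(b)
plt.show()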

[Figure: the result, the picture shown as 9 grayscale copies]

Upvotes: 4

Views: 542

Answers (1)

Nicolas Gervais

Reputation: 36684

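transforms.ToTensor() converts an H x W x C NumPy array (or PIL image) into a C x H x W float tensor, scaling uint8 values to [0.0, 1.0]. So right after transform(a), b has shape (3, 375, 500). b.view(375, 500, 3) does not move the channel axis; it only reinterprets the same contiguous memory with a new shape. The first 125 rows of the result therefore come from the red plane, the next 125 from green and the last 125 from blue, and each new row packs three consecutive original rows side by side, which produces a 3 x 3 grid of shrunken copies. The three "channels" of each new pixel are three horizontally adjacent values from the same color plane, so they are nearly equal and the tiles look gray. For plotting, you need to move the channels back to the last dimension with permute instead of view: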

plt.imshow(transform(a).permute(1, 2, 0))  # C x H x W -> H x W x C
plt.show()

Upvotes: 3
