Hekes Pekes

Reputation: 1325

Python matplotlib, invalid shape for image data

Currently I have this code to show three images:

imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')

It works fine, but the images are shown in a column, one below the other. I am trying to put all three in a single row instead.

Here is the code I have tried:

import matplotlib.pyplot as plt

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)

It throws

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

If I move the tensors to the CPU first:

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())

It throws

TypeError: Invalid shape (1, 3, 128, 128) for image data

How should I fix this, or is there an easier way to implement it?

Upvotes: 4

Views: 24319

Answers (1)

A. Maman

Reputation: 972

Matplotlib's imshow expects a 3-channel (RGB) image with shape (h, w, 3), as you can see in the documentation (it also accepts (h, w) grayscale and (h, w, 4) RGBA arrays).

It seems that you passed a "batch" containing a single image (the first dimension) with three channels (the second dimension), where h and w are the third and fourth dimensions. This is the usual PyTorch (N, C, H, W) layout.
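A minimal NumPy-only sketch that reproduces the problem, assuming matplotlib with the non-interactive Agg backend. The zero-filled arrays are stand-ins for the tensors in the question (the error message shows the shape (1, 3, 128, 128)); no PyTorch is needed to see the behaviour:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for image1 after .cpu(): batch of 1, 3 channels, 128 x 128 pixels (N, C, H, W)
batch = np.zeros((1, 3, 128, 128), dtype=np.float32)

message = None
try:
    plt.imshow(batch)  # 4-D input: imshow rejects it
except TypeError as err:
    message = str(err)
print(message)  # Invalid shape (1, 3, 128, 128) for image data

# The same amount of data laid out as (H, W, C) is accepted
hwc = np.zeros((128, 128, 3), dtype=np.float32)
plt.imshow(hwc)
plt.close("all")
```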

You need to rearrange the dimensions of your image. After moving it to the CPU, try:

image1.squeeze().permute(1,2,0)

The result will be an image of the desired shape (128, 128, 3).

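The squeeze() call removes the first (batch) dimension. Note that without an argument it drops every dimension of size 1, so squeeze(0) is the safer choice if another dimension could also be 1. The permute(1, 2, 0) call then reorders the dimensions: the channel dimension moves from the first position to the last, and the two spatial dimensions (h and w) move to the front.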
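To see that this reordering moves data rather than just relabelling the shape, here is the same operation in NumPy, where squeeze(0) and transpose(1, 2, 0) play the role of the PyTorch calls. The helper name to_hwc and the random input are hypothetical, for illustration only:

```python
import numpy as np

def to_hwc(batch):
    """Hypothetical helper: turn a (1, C, H, W) array into an (H, W, C) array.

    Mirrors image.squeeze(0).permute(1, 2, 0) on a PyTorch tensor.
    """
    chw = batch.squeeze(0)         # (1, C, H, W) -> (C, H, W): drop the batch dimension
    return chw.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C): channels move to the end

rng = np.random.default_rng(0)
x = rng.random((1, 3, 128, 128))
y = to_hwc(x)

print(y.shape)  # (128, 128, 3)
# Each pixel keeps its value; only the axis order changes: x[0, c, i, j] == y[i, j, c]
assert y[5, 7, 2] == x[0, 2, 5, 7]
```

A plain reshape to (128, 128, 3) would also have the right shape, but it would interleave the channel planes incorrectly; transposing (or permute in PyTorch) is what keeps each pixel's RGB values together.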

Also, have a look here for further discussion of the GPU/CPU issue: link
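For the original goal of showing the three images side by side, a minimal sketch using plt.subplots, which creates one row of axes in a single call. It is NumPy-only so it runs without a GPU; with the PyTorch tensors from the question you would first move each one to host memory and convert it, e.g. image1.detach().cpu().squeeze(0).permute(1, 2, 0).numpy(). The helper name show_in_row and the random stand-in images are hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook or GUI session
import matplotlib.pyplot as plt
import numpy as np

def show_in_row(images, titles):
    """Hypothetical helper: draw (1, C, H, W) arrays side by side in one row."""
    n = len(images)
    # squeeze=False always returns a 2-D array of axes, even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for ax, img, title in zip(axes[0], images, titles):
        hwc = img.squeeze(0).transpose(1, 2, 0)  # (1, C, H, W) -> (H, W, C) for imshow
        ax.imshow(hwc)
        ax.set_title(title)
        ax.axis("off")  # hide pixel-coordinate ticks
    fig.tight_layout()
    return fig

rng = np.random.default_rng(0)
images = [rng.random((1, 3, 128, 128)) for _ in range(3)]  # stand-ins for image1..image3
fig = show_in_row(images, titles=["1", "2", "3"])
# plt.show()  # display the figure in an interactive session
```

Using one plt.subplots call instead of repeated add_subplot calls keeps the axes in a list, so the titles and images can be handled in a single loop.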

Hope that helps.

Upvotes: 12
