Abhik

Reputation: 1940

Plot colour image from a numpy array that has 3 channels

In my Jupyter notebook I am trying to display images that I am iterating over through Keras. The code I am using is below:

import matplotlib.pyplot as plt
import file_utils  # my own helper module (not shown)

def plotImages(path, num):
    batchGenerator = file_utils.fileBatchGenerator(path + "train/", num)
    imgs, labels = next(batchGenerator)
    fig = plt.figure(figsize=(224, 224))
    plt.gray()
    for i in range(num):
        sub = fig.add_subplot(num, 1, i + 1)
        # imgs[i, 0] selects only channel 0 of image i
        sub.imshow(imgs[i, 0], interpolation='nearest')

But this only plots a single channel, so my image appears in grayscale. How do I use all 3 channels to plot a colour image?

Upvotes: 11

Views: 27125

Answers (2)

HRn

Reputation: 1

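This should work: put the channels last by permuting the axes,

image.permute(2, 3, 1, 0)

and then remove the image index with np.squeeze():

plt.imshow(np.squeeze(image.permute(2, 3, 1, 0)))

Note that permute is a PyTorch tensor method; for a NumPy array of shape (1, 3, H, W) the equivalent is np.squeeze(np.transpose(image, (2, 3, 1, 0))). The squeeze only removes the image index when the batch holds a single image.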
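The same reordering can be sketched with NumPy alone, assuming a single channels-first image of shape (1, 3, H, W); the random array below is a hypothetical stand-in for real data.

```python
import numpy as np

# Hypothetical stand-in for one channels-first image: (batch=1, channels=3, height=4, width=5)
image = np.random.rand(1, 3, 4, 5)

# np.transpose reorders the axes to (height, width, channels, batch)
hwcn = np.transpose(image, (2, 3, 1, 0))
print(hwcn.shape)  # (4, 5, 3, 1)

# Drop only the trailing batch axis so imshow receives (height, width, 3)
rgb = np.squeeze(hwcn, axis=-1)
print(rgb.shape)  # (4, 5, 3)
```

Passing axis=-1 to np.squeeze is a deliberate choice: a bare np.squeeze would also drop the height or width axis if either happened to be 1.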

Upvotes: 0

Suever

Reputation: 65460

If you want to display an RGB image, you have to supply all three channels. Based on your code, you are instead displaying just the first channel, so matplotlib has no information to display it as RGB. Instead, it maps the values through the gray colormap, since you've called plt.gray().

Instead, you'll want to pass all channels of the RGB image to imshow; then the true color display is used and the figure's colormap is disregarded:

sub.imshow(imgs, interpolation='nearest')
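A small self-contained sketch of this difference, using a hypothetical random (height, width, 3) array in place of a real image and the non-interactive Agg backend so it runs outside a notebook:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical (height, width, 3) image with float values in [0, 1]
rgb = np.random.rand(8, 8, 3)

fig, (ax_gray, ax_rgb) = plt.subplots(1, 2)
plt.gray()  # sets the default colormap to gray, as in the question

# A 2-D array is scalar data, so it is mapped through the (gray) colormap
gray_im = ax_gray.imshow(rgb[:, :, 0], interpolation='nearest')

# A (height, width, 3) array is interpreted as RGB, so the colormap is not used
rgb_im = ax_rgb.imshow(rgb, interpolation='nearest')

print(gray_im.get_cmap().name)   # gray
print(rgb_im.get_array().shape)  # (8, 8, 3)
```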

Update

Since imgs is actually 2 x 3 x 224 x 224 (batch x channels x height x width), you'll want to index into imgs and permute the dimensions to be 224 x 224 x 3 before displaying the image:

im2display = imgs[i].transpose((1, 2, 0))
sub.imshow(im2display, interpolation='nearest')
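Putting this into the question's loop, a minimal sketch might look like the following. It assumes the batch is a channels-first NumPy array of shape (N, 3, H, W); plot_images and the random stand-in batch are hypothetical, replacing file_utils.fileBatchGenerator so the example runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line inside Jupyter
import matplotlib.pyplot as plt
import numpy as np

def plot_images(imgs):
    """Hypothetical helper: plot each (3, H, W) image of an (N, 3, H, W) batch in colour."""
    num = imgs.shape[0]
    fig = plt.figure(figsize=(4, 4 * num))  # figsize is in inches, not pixels
    for i in range(num):
        sub = fig.add_subplot(num, 1, i + 1)
        # Reorder (3, H, W) -> (H, W, 3) so imshow treats the data as RGB
        im2display = imgs[i].transpose((1, 2, 0))
        sub.imshow(im2display, interpolation='nearest')
    return fig

# Hypothetical stand-in for one Keras batch: 2 images, 3 channels, 224 x 224, floats in [0, 1]
imgs = np.random.rand(2, 3, 224, 224)
fig = plot_images(imgs)
```

imshow expects RGB floats in [0, 1] or integers in [0, 255]. If the generator yields floats in the 0 to 255 range, they are clipped to 1 and the image looks almost white, so dividing by 255 (or casting to uint8) first is usually needed.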

Upvotes: 11
