Cobes

Reputation: 113

How to read a .jpg into a TensorFlow dataset and display the image with a session

I am trying to load an image from a file through tf.data.Dataset and then display the image with matplotlib. I have several more files to extend this to once I understand how loading a single image works. I do not understand what is going wrong here. What is causing the error below, and how do I correct this code so that I can display the image?

I am using TensorFlow 1.14.

import tensorflow as tf
import matplotlib.pyplot as plt

filename = tf.constant(['D:/Datasets/The Oxford-IIIT Pet Dataset (Segmentation)/images/Abyssinian_1.jpg'])

dataset = tf.data.Dataset.from_tensor_slices(filename)

def format_image(image_dir):
    image = tf.read_file(image_dir)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_image_with_pad(image, 256, 256, align_corners=True)
    return image

dataset = dataset.map(format_image)
dataset = dataset.batch(1)

iterator = dataset.make_initializable_iterator()
image = iterator.get_next()

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    plt.imshow(decoded_image)
    plt.show()

I am getting the error:

Traceback (most recent call last):
  File "C:/Users/g/Deeplab_custom/readinganimage.py", line 24, in <module>
    plt.imshow(decoded_image)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\pyplot.py", line 2699, in imshow
    None else {}), **kwargs)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\__init__.py", line 1810, in inner
    return func(ax, *args, **kwargs)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\axes\_axes.py", line 5494, in imshow
    im.set_data(X)
  File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\image.py", line 638, in set_data
    raise TypeError("Invalid dimensions for image data")
TypeError: Invalid dimensions for image data

Upvotes: 0

Views: 1861

Answers (1)

thushv89

Reputation: 11333

With your current code, decoded_image is effectively a [1, 1, 256, 256, 3] array. The dimensions are: an extra leading dimension introduced by wrapping image in square brackets in sess.run (which returns a list with one element), then the batch dimension, height, width and channels. matplotlib does not understand this shape; imshow needs a [height, width, channels] array.
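To see where the extra dimension comes from without running TensorFlow, here is a minimal NumPy sketch. It assumes the fetched batch is a float32 array of shape [1, 256, 256, 3] (a zero-filled stand-in, not your real image), and imshow_accepts is a hypothetical helper that only approximates the shape check matplotlib performs before raising "Invalid dimensions for image data".

```python
import numpy as np

# Zero-filled stand-in for what dataset.batch(1) yields:
# [batch, height, width, channels] = [1, 256, 256, 3].
batch = np.zeros((1, 256, 256, 3), dtype=np.float32)

# sess.run([image]) returns a list holding that one array,
# so matplotlib ends up seeing one more leading dimension.
fetched_as_list = [batch]
print(np.asarray(fetched_as_list).shape)  # (1, 1, 256, 256, 3)

def imshow_accepts(shape):
    """Hypothetical, approximate version of matplotlib's shape rule:
    (H, W), (H, W, 3) or (H, W, 4) are accepted."""
    return len(shape) == 2 or (len(shape) == 3 and shape[-1] in (3, 4))

print(imshow_accepts(np.asarray(fetched_as_list).shape))  # False
# Index past the list wrapper and the batch dimension.
print(imshow_accepts(fetched_as_list[0][0].shape))        # True
```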

So in your case, you can keep sess.run([image]) and index past both leading dimensions:

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])  # a list holding one [1, 256, 256, 3] array
    plt.imshow(decoded_image[0][0])    # [256, 256, 3]
    plt.show()

However, the outer dimension is unnecessary: it only exists because you call sess.run([image]). Call sess.run(image) instead and index just the batch dimension:

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0])
    plt.show()
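If you are not sure how many leading size-1 dimensions a fetch will have, a small helper can strip them. This is a sketch under assumptions: strip_leading_singletons is a hypothetical name, not a TensorFlow or matplotlib function, and the zero-filled arrays below only stand in for the two fetch styles.

```python
import numpy as np

def strip_leading_singletons(array):
    """Hypothetical helper: drop leading size-1 axes (the list wrapper and a
    batch of one) until the array is [height, width, channels]."""
    arr = np.asarray(array)
    while arr.ndim > 3 and arr.shape[0] == 1:
        arr = arr[0]
    return arr

# Stand-ins for the two fetch styles discussed above.
from_list_fetch = [np.zeros((1, 256, 256, 3), dtype=np.float32)]  # sess.run([image])
from_plain_fetch = np.zeros((1, 256, 256, 3), dtype=np.float32)   # sess.run(image)

print(strip_leading_singletons(from_list_fetch).shape)   # (256, 256, 3)
print(strip_leading_singletons(from_plain_fetch).shape)  # (256, 256, 3)
```

With it, plt.imshow(strip_leading_singletons(sess.run(image))) works for either fetch style. Removing only leading axes, rather than calling np.squeeze on everything, leaves the height, width and channel axes intact even if one of them happens to be 1. For a batch of one, plain decoded_image[0] is simpler.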

Upvotes: 1
