Reputation: 113
I am trying to load an image from file through tf.data.Dataset and then display it with matplotlib. I have several more files to extend this to once I understand how loading a single image works. I do not understand what is going wrong here. What is causing the error below, and how do I correct this code so that I can display the image?
I am using tensorflow 1.14
import tensorflow as tf
import matplotlib.pyplot as plt

filename = tf.constant(['D:/Datasets/The Oxford-IIIT Pet Dataset (Segmentation)/images/Abyssinian_1.jpg'])
dataset = tf.data.Dataset.from_tensor_slices(filename)

def format_image(image_dir):
    image = tf.read_file(image_dir)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_image_with_pad(image, 256, 256, align_corners=True)
    return image

dataset = dataset.map(format_image)
dataset = dataset.batch(1)
iterator = dataset.make_initializable_iterator()
image = iterator.get_next()

with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    plt.imshow(decoded_image)
    plt.show()
I am getting the error:
Traceback (most recent call last):
File "C:/Users/g/Deeplab_custom/readinganimage.py", line 24, in <module>
plt.imshow(decoded_image)
File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\pyplot.py", line 2699, in imshow
None else {}), **kwargs)
File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\__init__.py", line 1810, in inner
return func(ax, *args, **kwargs)
File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\axes\_axes.py", line 5494, in imshow
im.set_data(X)
File "C:\Users\s\AppData\Local\Programs\Python\Python36\lib\site-packages\matplotlib\image.py", line 638, in set_data
raise TypeError("Invalid dimensions for image data")
TypeError: Invalid dimensions for image data
Upvotes: 0
Views: 1861
Reputation: 11333
The way you currently have the code, it returns a [1, 1, 256, 256, 3] output. These dimensions are [dimension introduced by using square brackets in sess.run, batch dimension, height, width, channels], which is not understood by matplotlib. matplotlib needs a [height, width, channels] array.
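To see where those two leading size-1 dimensions come from, here is a minimal NumPy sketch (independent of TensorFlow, with a zero array standing in for the decoded image): sess.run with a list of fetches returns a list of results, and that wrapping adds one dimension on top of the batch dimension from dataset.batch(1).

```python
import numpy as np

# Stand-in for the batched image tensor: [batch, height, width, channels]
batched = np.zeros((1, 256, 256, 3), dtype=np.float32)

# sess.run([image]) returns a list of fetched values; stacking that list
# into an array adds one more leading dimension
fetched = np.array([batched])
print(fetched.shape)        # (1, 1, 256, 256, 3)

# Indexing twice recovers the [height, width, channels] array matplotlib wants
print(fetched[0][0].shape)  # (256, 256, 3)
```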
So in your case, you can do the following.
with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run([image])
    plt.imshow(decoded_image[0][0])
But you are introducing an unnecessary dimension by using sess.run([image]); instead, do sess.run(image) and then the following.
with tf.Session() as sess:
    sess.run([iterator.initializer])
    decoded_image = sess.run(image)
    plt.imshow(decoded_image[0])
    plt.show()
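As an aside, if you are ever unsure how many size-1 dimensions a fetched result carries, np.squeeze removes them all in one call. This is an alternative not shown in the answer above, sketched here with a zero array standing in for the fetched image:

```python
import numpy as np

# Stand-in for a result with shape [1, 1, 256, 256, 3]
decoded_image = np.zeros((1, 1, 256, 256, 3), dtype=np.float32)

# np.squeeze drops every size-1 axis, leaving [height, width, channels]
squeezed = np.squeeze(decoded_image)
print(squeezed.shape)  # (256, 256, 3)
```

Note that squeeze would also drop a genuine size-1 axis (e.g. a single-channel image), so indexing explicitly, as in decoded_image[0], is safer when you know the layout.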
Upvotes: 1