Reputation: 6799
I am trying to plot sample images (img.jpg files) belonging to the 3 classes in my dataset with the following code:
import matplotlib.pyplot as plt

# Plot the first 9 images of one batch in a 3x3 grid
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()
But I am getting the following error:
InvalidArgumentError: assertion failed: [Unable to decode bytes as JPEG, PNG, GIF, or BMP]
EDIT:
train_dataset:
import tensorflow as tf

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=TRAIN_DIR,
    labels="inferred",
    label_mode="int",
    class_names=["0", "5", "10"],
    batch_size=BATCH_SIZE,
    image_size=(TARGETX, TARGETY),
    shuffle=True,
    seed=SEED,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
)
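For reference, with labels="inferred" each entry of class_names names a subfolder of TRAIN_DIR, its position in the list becomes the integer label, and only files with a supported extension (.bmp, .gif, .jpeg, .jpg, .png) are read. The sketch below roughly approximates that indexing with the standard library, so you can list exactly which files would be handed to the decoder. It is not the Keras implementation; the helper name index_image_files is hypothetical, and the recursive walk into nested subfolders reflects my understanding of how the Keras indexer behaves.

```python
import os

# Extensions accepted by image_dataset_from_directory (per the Keras docs).
ALLOWED_EXTENSIONS = (".bmp", ".gif", ".jpeg", ".jpg", ".png")


def index_image_files(directory, class_names):
    """Return (paths, labels), approximating labels="inferred".

    Hypothetical helper: each class name is a subfolder of `directory`,
    and its index in `class_names` becomes the integer label. Subfolders
    are walked recursively, so deeply nested images are included.
    """
    paths, labels = [], []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(directory, class_name)
        # os.walk yields bare file names; join each with its own root.
        for root, _dirs, files in os.walk(class_dir):
            for name in sorted(files):
                if name.lower().endswith(ALLOWED_EXTENSIONS):
                    paths.append(os.path.join(root, name))
                    labels.append(label)
    return paths, labels
```

Usage, with the question's own names: `paths, labels = index_image_files(TRAIN_DIR, ["0", "5", "10"])`, then print `len(paths)` and a few entries to confirm the expected files are picked up.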
Upvotes: 0
Views: 674
Reputation: 6799
I got the outcome I wanted. The issue was that my images are nested deeply in their folders, so I needed to extract their paths before plotting them.
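import os
import matplotlib.image as mpimg

# One list of file names per directory (deepest directories first, because topdown=False)
IMG_0 = [files for root, directories, files in os.walk(test_img_0, topdown=False)]
IMG_5 = [files for root, directories, files in os.walk(test_img_5, topdown=False)]
IMG_10 = [files for root, directories, files in os.walk(test_img_10, topdown=False)]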
import matplotlib.pyplot as plt
import numpy as np

def show_images(images, cols=1, titles=None):
    """Display a list of images in a single figure with matplotlib.

    Parameters
    ----------
    images: List of np.arrays compatible with plt.imshow.
    cols (Default = 1): Number of columns in figure (number of rows is
                        set to np.ceil(n_images/float(cols))).
    titles: List of titles corresponding to each image. Must have
            the same length as images.
    """
    assert (titles is None) or (len(images) == len(titles))
    n_images = len(images)
    if titles is None:
        titles = ["Image (%d)" % i for i in range(1, n_images + 1)]
    # add_subplot expects integer (nrows, ncols, index) arguments
    n_rows = int(np.ceil(n_images / float(cols)))
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(n_rows, cols, n + 1)
        if image.ndim == 2:
            plt.gray()  # show 2-D (single-channel) arrays in grayscale
        plt.imshow(image)
        a.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()
Then I built the path to one image per class, read each one with mpimg.imread, and passed the resulting arrays to the function as a list:
img1 = mpimg.imread(os.path.join(test_img_0, IMG_0[0][0]))
img2 = mpimg.imread(os.path.join(test_img_5, IMG_5[0][0]))
img3 = mpimg.imread(os.path.join(test_img_10, IMG_10[0][0]))
show_images([img1, img2, img3], cols = 2, titles=["Class 0", "Class 5", "Class 10"])
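One caveat with the snippet above: os.walk returns bare file names, so os.path.join(test_img_0, IMG_0[0][0]) only yields a valid path when test_img_0 has no subfolders and the file sits directly inside it. A minimal sketch that joins each name with the directory it was actually found in; the helper collect_image_paths and its default extension list are illustrative assumptions, not part of the original answer.

```python
import os


def collect_image_paths(top_dir, extensions=(".jpg", ".jpeg", ".png")):
    """Return the full path of every image file anywhere under top_dir.

    Hypothetical helper: os.walk yields (root, dirs, files) where `files`
    holds bare names, so each name is joined with its own `root`, not
    with `top_dir`. Matching on the extension is case-insensitive.
    """
    paths = []
    for root, _dirs, files in os.walk(top_dir):
        for name in files:
            if name.lower().endswith(extensions):
                paths.append(os.path.join(root, name))
    return sorted(paths)
```

With this, reading the first image of a class becomes `img1 = mpimg.imread(collect_image_paths(test_img_0)[0])`, regardless of how deeply it is nested.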
Upvotes: 0
Reputation: 3564
There is nothing wrong with the plotting function. This is most likely an issue with your dataset.
See this issue.
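The error usually means that at least one file picked up by its extension does not actually begin like a JPEG, PNG, GIF, or BMP file (for example an empty file, or a different format saved under a .jpg name). A quick, heuristic way to list candidates is to compare each file's leading "magic" bytes against those four formats. This sketch uses only the standard library; the helper names are hypothetical, and a file that passes the header check can still be corrupt further in. Reading and decoding each file with tf.io.read_file and tf.io.decode_image is a stricter alternative.

```python
import os

# Leading "magic" bytes of the four formats named in the error message.
SIGNATURES = (
    b"\xff\xd8\xff",        # JPEG
    b"\x89PNG\r\n\x1a\n",   # PNG
    b"GIF87a", b"GIF89a",   # GIF
    b"BM",                  # BMP
)
IMAGE_EXTENSIONS = (".bmp", ".gif", ".jpeg", ".jpg", ".png")


def looks_like_supported_image(path):
    """Cheap header check only: a passing file may still be truncated or corrupt."""
    with open(path, "rb") as f:
        header = f.read(8)
    return header.startswith(SIGNATURES)


def find_suspect_images(directory):
    """List image-named files under `directory` whose header matches no supported format."""
    suspects = []
    for root, _dirs, files in os.walk(directory):
        for name in sorted(files):
            if name.lower().endswith(IMAGE_EXTENSIONS):
                path = os.path.join(root, name)
                if not looks_like_supported_image(path):
                    suspects.append(path)
    return suspects
```

Running `for path in find_suspect_images(TRAIN_DIR): print(path)` lists the files worth removing or re-saving before building the dataset again.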
If you want to confirm that the plotting code itself is correct, run it on a known-good dataset such as the TensorFlow flowers dataset:
import pathlib

import tensorflow as tf

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,
                                   fname='flower_photos',
                                   untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32
img_height = 180
img_width = 180

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_dataset.class_names
Note: this downloads the flowers dataset and builds a tf.data.Dataset object. Pass this train_dataset to the plotting loop from the question; if those images plot without errors, the problem lies in your own image files rather than in the plotting code.
Checks for your dataset:
Upvotes: 1