vegiv

Reputation: 152

Showing images from a TensorFlow dataset class

I'm trying to make a training set and a validation set from two folders, one with images of open hands and another with images of closed hands. The file structure is images/closed and images/open. I'm using tensorflow.keras.preprocessing.image_dataset_from_directory, which I'm not familiar with. I'm on Python 3.7.4 and TF 2.3.1.

import tensorflow as tf
import numpy as np

path = './images'

dataset_train = tf.keras.preprocessing.image_dataset_from_directory(
    path,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="grayscale",
    batch_size=32,
    image_size=(50, 50),
    shuffle=True,
    seed=123,
    validation_split=0.1,
    subset='training',
    interpolation="bilinear",
    follow_links=False,
)
dataset_test = tf.keras.preprocessing.image_dataset_from_directory(
    path,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="grayscale",
    batch_size=32,
    image_size=(50, 50),
    shuffle=True,
    seed=123,
    validation_split=0.1,
    subset='validation',
    interpolation="bilinear",
    follow_links=False,
)

I'm just testing how it works with this code from TensorFlow:

import matplotlib.pyplot as plt

class_names = dataset_train.class_names

plt.figure(figsize=(10, 10))
for images, labels in dataset_train.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

When I run it, I get this error: "TypeError: only integer scalar arrays can be converted to a scalar index." I tried using

class_names=np.array(dataset_train.class_names)

But then I got this error: "IndexError: only integers, slices (:), ellipsis (...), numpy.newaxis (None) and integer or boolean arrays are valid indices"

Also, I'm not quite sure how the take(1) call works in TensorFlow.
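take(1) is a method of tf.data.Dataset rather than an argument to the loader: it returns a new dataset containing at most the first element of the original one. Because image_dataset_from_directory yields whole batches of (images, labels), that single element is one batch of up to 32 images. A minimal sketch, using a toy Dataset.range pipeline as a stand-in for the image data (the toy values are illustrative only):

```python
import tensorflow as tf

# Toy stand-in for the image dataset: ten elements grouped into batches
# of four, so the dataset yields [0 1 2 3], [4 5 6 7] and [8 9].
dataset = tf.data.Dataset.range(10).batch(4)

# take(1) builds a new dataset holding only the first element, which here
# is the first *batch*, not the first single item.
first = dataset.take(1)

for batch in first:
    print(batch.numpy())  # prints: [0 1 2 3]

# Asking for more elements than exist simply stops at the end of the data.
print(len(list(dataset.take(5))))  # prints: 3
```

So in the plotting loop, images and labels are the tensors of one batch, and images[i] / labels[i] pick the i-th example inside it.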

Upvotes: 1

Views: 483

Answers (1)

Mikhail Golubitsky

Reputation: 698

I just ran into this error in a similar situation. For me, it occurred on the line

plt.title(class_names[labels[i]])

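because we are using label_mode="categorical", which returns each label as a one-hot encoded vector whose length equals the number of classes (for example [0., 1.] for the second class).

Therefore, we cannot use labels[i] directly as an index into class_names: it is a vector, not an integer.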
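A minimal NumPy sketch of the problem and the usual fix. It assumes the class order ['closed', 'open'], which is what image_dataset_from_directory infers from the alphanumerically sorted folder names. The one-hot label is synthetic, standing in for one row of labels from the dataset, and label_to_name is a hypothetical helper name. np.argmax turns a one-hot vector back into an integer class index:

```python
import numpy as np

# Class order as inferred from the subfolders: closed -> 0, open -> 1.
class_names = ["closed", "open"]

# With label_mode="categorical", one label looks like this (one-hot,
# float32). This synthetic example is an "open" image.
one_hot_label = np.array([0.0, 1.0], dtype=np.float32)


def label_to_name(label, names):
    """Map a one-hot (or probability) vector to its class name."""
    # argmax returns the position of the largest value, i.e. the class index.
    index = int(np.argmax(label))
    return names[index]


print(label_to_name(one_hot_label, class_names))  # prints: open

# The two errors from the question, reproduced with plain NumPy:
try:
    class_names[one_hot_label]  # Python list indexed by an array
except TypeError as err:
    print("TypeError:", err)

try:
    np.array(class_names)[one_hot_label]  # NumPy array indexed by floats
except IndexError as err:
    print("IndexError:", err)
```

In the plotting loop from the question, the title line would then become plt.title(class_names[np.argmax(labels[i])]). Alternatively, the original tutorial code works unchanged with the loader's default label_mode="int", where each label is already a single integer class index; a model trained on such labels then needs an integer-label loss such as sparse categorical cross-entropy.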

Upvotes: 1
