JKnecht

Reputation: 241

How to get the class names in a tensorflow dataset?

I am trying to understand how to work with TensorFlow Datasets (tfds).

The dataset is a directory with this layout:

-dataset
  -train
    -class_name1
      -files...
    -class_name2
      -files...
  -val
    -class_name1
      -files...
    -class_name2
      -files...
  -test
    -class_name1
      -files...
    -class_name2
      -files...

Here is some code:

import tensorflow_datasets as tfds
builder = tfds.ImageFolder('/content/dataset')
train_ds, val_ds, test_ds = builder.as_dataset(split=['train', 'val', 'test'], shuffle_files=True, as_supervised=True)
print(builder.info)
Output:
tfds.core.DatasetInfo(
    name='image_folder',
    version=1.0.0,
    description='Generic image classification dataset.',
    homepage='https://www.tensorflow.org/datasets/catalog/image_folder',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=243),
    }),
    total_num_examples=73090,
    splits={
        'test': 5849,
        'train': 58469,
        'val': 8772,
    },
    supervised_keys=('image', 'label'),
    citation="""""",
    redistribution_info=,
)

When I am plotting or producing classification reports, confusion matrices, etc., I want to be able to use the class names rather than the integer labels.

Is there some easy command that gives me access to the class names? (There are 243 classes, not 2.)

Upvotes: 3

Views: 8744

Answers (3)

Srinivas

Reputation: 656

This question is a bit old, but here is an answer anyway.

First, create an iterator over your dataset (for example, the train dataset) with the iter() function, then access each observation with the next() function.

Then apply the .int2str() method of the label feature, available through the builder.info object, to convert each integer label to its class name.

For example, to plot the first 32 images with their labels, you could write a for loop like this:

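import matplotlib.pyplot as plt

plt.figure(figsize=(15, 10))
iterator = iter(train_ds)
for i in range(32):
  img, label = next(iterator)
  plt.subplot(4, 8, i + 1)
  plt.imshow(img)
  # Convert the scalar label tensor to a Python int before looking up its name
  plt.title(builder.info.features['label'].int2str(int(label)))
plt.tight_layout()
plt.show()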

Hope this helps.

Upvotes: 0

A MARTIN

Reputation: 31

Your initial builder:

builder = tfds.ImageFolder(path_to_data)

As for extracting the class names: they are somewhat hidden. With this builder, they are not available as train_ds.class_names.

You have to use your builder directly instead.

Start from builder.info (https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetInfo)

Then take .features (https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict)

Under the 'label' key of that FeaturesDict you will find a ClassLabel (https://www.tensorflow.org/datasets/api_docs/python/tfds/features/ClassLabel):

print(builder.info.features['label'])
ClassLabel(shape=(), dtype=tf.int64, num_classes=99)

Inside it there is a names argument/attribute, described as:

names: list< str >, string names for the integer classes. The order in which the names are provided is kept.

So, in the end, you get your answer directly with:

class_names = builder.info.features['label'].names
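
Once you have that list, turning integer labels into names for a report is plain list indexing. Below is a minimal sketch in pure Python; the class_names values and the y_true / y_pred labels are made up for illustration and stand in for builder.info.features['label'].names and your model's output.

```python
from collections import Counter

# Stand-in for builder.info.features['label'].names (hypothetical values).
class_names = ["class_name1", "class_name2", "class_name3"]

# Hypothetical integer labels: ground truth and model predictions.
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]

def to_names(labels, names):
    """Map integer labels to class names by list indexing."""
    return [names[int(i)] for i in labels]

true_names = to_names(y_true, class_names)
pred_names = to_names(y_pred, class_names)

# Count (true, predicted) pairs: a sparse confusion matrix keyed by name.
confusion = Counter(zip(true_names, pred_names))

for (true_name, pred_name), count in sorted(confusion.items()):
    print(f"true={true_name:<12} predicted={pred_name:<12} count={count}")
```

If you use scikit-learn, classification_report accepts a target_names argument, so you can pass class_names there instead of converting the labels yourself.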

Upvotes: 3

Nicolae Natea

Reputation: 1194

You have access to train_ds, val_ds, test_ds, and they have class_names, which you could access using the label as an index, as mentioned by fpierrem.

Here's an example of a plot:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(train_ds.class_names[labels[i]])
        plt.axis("off")
    
plt.show()

See https://www.tensorflow.org/tutorials/keras/classification#import_the_fashion_mnist_dataset and https://www.tensorflow.org/tutorials/load_data/images#visualize_the_data

Upvotes: -1
