Lleims

Reputation: 1353

Get corresponding class id given class name from a dataset

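I'm trying to work with the CIFAR-10 dataset using TensorFlow Datasets, loading it like this:

import tensorflow_datasets as tfds

# Load the training split together with the dataset metadata
ds, ds_info = tfds.load('cifar10', with_info=True, split='train')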

Now I'm trying to understand it as well as possible. From reading ds_info, I know the dataset has the following features:

FeaturesDict({
    'id': Text(shape=(), dtype=tf.string),
    'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

So far, I know how many classes it has and how to print them:

print(ds_info.features['label'].num_classes)
print(ds_info.features['label'].names)

But how do I find the identifier associated with each class? I think I might have to use the id feature, but I'm not sure how.

Upvotes: 1

Views: 1303

Answers (1)

Lescurel

Reputation: 11631

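You can use the str2int method of the ClassLabel object. The class id is the integer stored in the label feature; the id feature identifies individual examples, not classes.

>>> for name in ds_info.features['label'].names:
...     print(name, ds_info.features['label'].str2int(name))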
airplane 0
automobile 1
bird 2
cat 3
deer 4
dog 5
frog 6
horse 7
ship 8
truck 9
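
If you need the whole mapping at once, it can be built as a plain dictionary. This is a minimal sketch that assumes, as the output above suggests, that a class id is simply the position of its name in the names list. The list is hard-coded here so the snippet runs without downloading anything; with tfds installed you would take it from ds_info.features['label'].names instead. The variable names name_to_id and id_to_name are only illustrative.

```python
# Class names in the order shown in the output above
# (assumption: same order as ds_info.features['label'].names).
names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
         'dog', 'frog', 'horse', 'ship', 'truck']

# The integer id of a class is its index in the names list.
name_to_id = {name: idx for idx, name in enumerate(names)}

# Reverse lookup: integer id back to class name.
id_to_name = dict(enumerate(names))

print(name_to_id['cat'])   # 3
print(id_to_name[7])       # horse
```

For single lookups in the other direction, ClassLabel also has an int2str method, which should give the same result as id_to_name here.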

Upvotes: 2
