I'm trying to work with the CIFAR-10 dataset in TensorFlow like this:
import tensorflow_datasets as tfds

ds, ds_info = tfds.load('cifar10', with_info=True, split='train')
Now I'm trying to understand it as well as possible. From reading ds_info, I know that each example has the following features:
FeaturesDict({
'id': Text(shape=(), dtype=tf.string),
'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})
So far, I know how many classes there are and how to print them:
print(ds_info.features['label'].num_classes)
print(ds_info.features['label'].names)
But how do I find the identifier associated with each class? I think I might have to use the id feature, but I'm not sure how.
You can use the str2int method of the ClassLabel object:
>>> for name in ds_info.features['label'].names:
...     print(name, ds_info.features['label'].str2int(name))
...
airplane 0
automobile 1
bird 2
cat 3
deer 4
dog 5
frog 6
horse 7
ship 8
truck 9
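As the output shows, the integer id of each class is simply its position in the names list, so you can also build both lookup directions yourself. The sketch below hard-codes the ten CIFAR-10 names from that output so it runs without TensorFlow; the names CIFAR10_NAMES, name_to_id, id_to_name and label_to_name are illustrative, not part of the TFDS API. With the real dataset, ds_info.features['label'].int2str(label) does the reverse lookup for you.

```python
# CIFAR-10 class names in the order printed above by ds_info.features['label'].names.
# Hard-coded here (an assumption for illustration) so the sketch runs without TensorFlow.
CIFAR10_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# The integer label is the index of the name in the list,
# so both mappings can be derived from the list alone.
name_to_id = {name: i for i, name in enumerate(CIFAR10_NAMES)}
id_to_name = dict(enumerate(CIFAR10_NAMES))


def label_to_name(label) -> str:
    """Map an integer label (e.g. example['label'] converted to int) to its class name."""
    return id_to_name[int(label)]


if __name__ == "__main__":
    # Reproduces the name/id table printed by str2int above.
    for name in CIFAR10_NAMES:
        print(name, name_to_id[name])
    print(label_to_name(3))
```

Note that the id feature is a per-example identifier string, not the class id; the class is carried by the label feature.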