Reputation: 241
I am trying to understand how to work with TensorFlow Datasets (tfds).
The dataset is a directory laid out like this:
-dataset
    -train
        -class_name1
            -files...
        -class_name2
            -files...
    -val
        -class_name1
            -files...
        -class_name2
            -files...
    -test
        -class_name1
            -files...
        -class_name2
            -files...
Here is some code:
import tensorflow_datasets as tfds
builder = tfds.ImageFolder('/content/dataset')
train_ds, val_ds, test_ds = builder.as_dataset(split=['train', 'val', 'test'], shuffle_files=True, as_supervised=True)
print(builder.info)
Output:
tfds.core.DatasetInfo(
name='image_folder',
version=1.0.0,
description='Generic image classification dataset.',
homepage='https://www.tensorflow.org/datasets/catalog/image_folder',
features=FeaturesDict({
'image': Image(shape=(None, None, 3), dtype=tf.uint8),
'image/filename': Text(shape=(), dtype=tf.string),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=243),
}),
total_num_examples=73090,
splits={
'test': 5849,
'train': 58469,
'val': 8772,
},
supervised_keys=('image', 'label'),
citation="""""",
redistribution_info=,
)
When I am plotting, or producing classification reports, confusion matrices and so on, I want to be able to use the class names rather than the integer labels.
Is there an easy command that gives me access to the class names? (There are 243 classes, not 2.)
Upvotes: 3
Views: 8744
Reputation: 656
This question is a bit old, but here is an answer anyway.
First, create an iterator over your dataset with iter() (for example, on your train dataset), then access each observation with next().
Then use the .int2str() method of the label feature in the builder.info object to convert each integer label to its class name.
For example, to plot the first 32 images with their corresponding labels, you could write a for loop like this:
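import matplotlib.pyplot as plt

plt.figure(figsize=(15, 10))
iterator = iter(train_ds)
for i in range(32):
    img, label = next(iterator)
    plt.subplot(4, 8, i + 1)
    plt.imshow(img)
    # Convert the scalar label tensor to a Python int before looking up its name
    plt.title(builder.info.features['label'].int2str(int(label)))
plt.tight_layout()
plt.show()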
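The same lookup can be applied to many examples at once, which is closer to what a classification report needs. The following is only a sketch: collect_label_names is a hypothetical helper (not part of tfds), and the stand-in data replaces a real dataset so that the snippet runs without TensorFlow installed.

```python
def collect_label_names(examples, int2str, limit=None):
    """Return the class names for the labels of (image, label) pairs.

    `examples` is any iterable of (image, label) pairs and `int2str` is any
    callable that maps an integer label to a class name.
    """
    names = []
    for i, (_, label) in enumerate(examples):
        if limit is not None and i >= limit:
            break
        # int() unwraps NumPy scalars (and plain ints) into a Python int
        names.append(int2str(int(label)))
    return names


# Stand-in data (an assumption for illustration, not the real dataset)
fake_names = ["class_name1", "class_name2"]
fake_examples = [(None, 1), (None, 0), (None, 1)]

first_two = collect_label_names(fake_examples, lambda i: fake_names[i], limit=2)
print(first_two)
```

With the objects from the question, the call would presumably look like collect_label_names(tfds.as_numpy(train_ds), builder.info.features['label'].int2str, limit=32).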
Hope this helps.
Upvotes: 0
Reputation: 31
Your initial builder:
builder = tfds.ImageFolder(path_to_data)
About extracting the class names: they are kind of hidden. They are not available as train_ds.class_names with this builder.
You have to use your builder directly instead.
Start from builder.info
(https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetInfo),
then take .features
(https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict).
The 'label' key of that FeaturesDict holds a ClassLabel
(https://www.tensorflow.org/datasets/api_docs/python/tfds/features/ClassLabel):
print(builder.info.features['label'])
ClassLabel(shape=(), dtype=tf.int64, num_classes=99)
Inside it there is a names argument/attribute, described as:
names: list< str >, string names for the integer classes. The order in which the names are provided is kept.
So, in the end, you get your answer directly with:
class_names = builder.info.features['label'].names
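Once the list of names is available, integer labels and predictions can be decoded before building a report. The sketch below uses only the standard library and stand-in values; to_names and confusion_counts are hypothetical helpers, and in the question's setting class_names would come from builder.info.features['label'].names rather than the hard-coded list used here.

```python
from collections import Counter


def to_names(int_labels, class_names):
    """Map integer labels to their class names."""
    return [class_names[int(i)] for i in int_labels]


def confusion_counts(y_true, y_pred, class_names):
    """Count (true name, predicted name) pairs, a sparse confusion matrix."""
    pairs = zip(to_names(y_true, class_names), to_names(y_pred, class_names))
    return Counter(pairs)


# Stand-in values (assumptions for illustration only)
class_names = ["cat", "dog", "fish"]
y_true = [0, 1, 2, 2]
y_pred = [0, 2, 2, 2]

counts = confusion_counts(y_true, y_pred, class_names)
for (true_name, pred_name), n in sorted(counts.items()):
    print(f"true={true_name} pred={pred_name} count={n}")
```

If scikit-learn is used instead, the same list can be passed as target_names to sklearn.metrics.classification_report, so the report shows names rather than integers.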
Upvotes: 3
Reputation: 1194
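You have access to train_ds, val_ds and test_ds. If they are created with tf.keras.utils.image_dataset_from_directory, as in the image-loading tutorial linked below, they have a class_names attribute, which you can index with the label, as mentioned by fpierrem. Datasets returned by builder.as_dataset do not have this attribute.
Here's an example of a plot: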
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(train_ds.class_names[labels[i]])
        plt.axis("off")
plt.show()
See https://www.tensorflow.org/tutorials/keras/classification#import_the_fashion_mnist_dataset and https://www.tensorflow.org/tutorials/load_data/images#visualize_the_data
Upvotes: -1