I have a dataset loaded from a directory using tf.keras.preprocessing.image_dataset_from_directory:
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.3,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
I want to change the data type of the images to make training faster.
I tried this, but it didn't work:
for image_batch, labels_batch in train_ds:
    image_batch = tf.cast(image_batch, tf.int16)
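Apply the map method to your dataset(s). map does not modify the dataset in place; it returns a new dataset, so assign the result back (and do the same for train_ds):
val_ds = val_ds.map(lambda x, y: (tf.cast(x, tf.int16), y))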
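A minimal, self-contained sketch to check the behaviour of map. Instead of reading from data_dir, it assumes a small synthetic dataset built with tf.data.Dataset.from_tensor_slices that yields (images, labels) batches like the ones from image_dataset_from_directory. It also assumes TensorFlow 2.4+ for tf.data.AUTOTUNE. Passing num_parallel_calls is an optional extra, not part of the answer above. The sketch shows that the original dataset keeps float32 images while the mapped one yields int16. This is also why reassigning image_batch inside a for loop has no effect.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the directory-loaded dataset: 6 random RGB
# "images" of size 4x4 with float32 pixel values in [0, 255], batched by 2,
# mimicking the (images, labels) batches of image_dataset_from_directory.
images = np.random.uniform(0, 255, size=(6, 4, 4, 3)).astype("float32")
labels = np.array([0, 1, 0, 1, 1, 0], dtype="int32")
val_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

# map() is lazy and returns a NEW dataset; val_ds itself is left unchanged.
# The (x, y) tuple of each element is unpacked into the lambda's arguments.
val_ds_int16 = val_ds.map(
    lambda x, y: (tf.cast(x, tf.int16), y),
    num_parallel_calls=tf.data.AUTOTUNE,  # optional: cast batches in parallel
)

print(val_ds.element_spec[0].dtype.name)        # float32
print(val_ds_int16.element_spec[0].dtype.name)  # int16

# Pull one batch to confirm the cast is applied to the actual tensors.
image_batch, labels_batch = next(iter(val_ds_int16))
print(image_batch.dtype.name, image_batch.shape)  # int16 (2, 4, 4, 3)
```

Because the cast happens inside the input pipeline, the training loop (e.g. model.fit(val_ds_int16, ...)) receives int16 tensors without any per-batch Python code.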