Yafaa

How to change the dtype of data in tf.data.Dataset?

I have a dataset loaded from a directory using this API:

import tensorflow as tf

# data_dir, img_height, img_width and batch_size are defined elsewhere
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.3,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
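
To check which dtype this loader actually yields, here is a minimal sketch that builds a tiny throwaway directory of random PNGs and inspects element_spec. The directory layout, the class names "cats"/"dogs", and the 8x8 image size are hypothetical, chosen only for illustration. It uses tf.keras.utils.image_dataset_from_directory, the newer home of the same loader in recent TF 2.x releases (older releases expose it under tf.keras.preprocessing).

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical throwaway dataset: two classes, three random 8x8 RGB PNGs each.
data_dir = tempfile.mkdtemp()
for class_name in ("cats", "dogs"):
    os.makedirs(os.path.join(data_dir, class_name))
    for i in range(3):
        pixels = tf.random.uniform((8, 8, 3), maxval=256, dtype=tf.int32)
        png = tf.io.encode_png(tf.cast(pixels, tf.uint8))
        tf.io.write_file(os.path.join(data_dir, class_name, f"{i}.png"), png)

ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, image_size=(8, 8), batch_size=2)

# element_spec is an (images, labels) pair of TensorSpecs.
images_spec, labels_spec = ds.element_spec
print(images_spec.dtype.name)  # float32
```

The images come out as float32 (pixel values are resized, not rescaled), which is the dtype the question wants to change.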

I want to change the data type of the images to make training faster.

I tried this, but it didn't work:

for image_batch, labels_batch in train_ds:
  image_batch = tf.cast(image_batch, tf.int16)


Answers (1)

PermanentPon

Just apply the map method to your dataset(s):

# map returns a new dataset, so assign the result
val_ds = val_ds.map(lambda x, y: (tf.cast(x, tf.int16), y))
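
Note that map does not modify a dataset in place; it returns a new tf.data.Dataset whose elements are transformed on the fly. The for loop in the question fails for a related reason: it only rebinds the local name image_batch and never writes anything back into the dataset. A minimal sketch comparing the two, assuming a small in-memory dataset built with from_tensor_slices as a hypothetical stand-in for val_ds (shapes and values are made up):

```python
import tensorflow as tf

# Hypothetical stand-in for val_ds: four float32 2x2 RGB "images" with int labels.
images = tf.random.uniform((4, 2, 2, 3), maxval=255.0)
labels = tf.constant([0, 1, 0, 1])
val_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

# Loop approach: only the loop variable is rebound, val_ds is unchanged.
for image_batch, labels_batch in val_ds:
    image_batch = tf.cast(image_batch, tf.int16)
print(val_ds.element_spec[0].dtype.name)  # float32

# map approach: a new dataset whose images are cast as they are produced.
val_ds = val_ds.map(
    lambda x, y: (tf.cast(x, tf.int16), y),
    num_parallel_calls=tf.data.AUTOTUNE,  # optional: run the cast in parallel
)
print(val_ds.element_spec[0].dtype.name)  # int16
```

The num_parallel_calls argument is optional; tf.data.AUTOTUNE is available in recent TF 2.x releases (older ones spell it tf.data.experimental.AUTOTUNE).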
