Intrastellar Explorer

Reputation: 2441

How to tf.cast a field within a tensorflow Dataset

I have a tf.data.Dataset that looks like this:

<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

The 2nd element (index 1 with zero-based indexing) holds the labels. I want to cast the labels to tf.uint8.

How can one use tf.cast on a single component of a tf.data.Dataset?


Similar Questions

How to convert tf.int64 to tf.float32? is very similar, but it is not about a tf.data.Dataset.


Repro

From the Keras example "Image classification from scratch":

curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip

Then in Python with tensorflow~=2.4:

import tensorflow as tf

ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages", batch_size=32
)
print(ds)
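Because the repro needs the PetImages download, here is a self-contained sketch that assumes random tensors as a stand-in: 4 hypothetical 256x256 RGB float32 images with int32 labels, batched so the structure matches the BatchDataset printed above. It applies Dataset.map with tf.cast to turn only the label component into tf.uint8.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the PetImages dataset: random float32 images
# and int32 labels, batched like image_dataset_from_directory output.
images = np.random.rand(4, 256, 256, 3).astype(np.float32)
labels = np.array([0, 1, 1, 0], dtype=np.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

# Cast only the label component to uint8; the images pass through unchanged.
ds_uint8 = ds.map(lambda x, y: (x, tf.cast(y, tf.uint8)))

# element_spec reports the shape and dtype of each component.
print(ds_uint8.element_spec)
```

The same `map` call should work on the real `ds` from the repro, since it has the same (images, labels) tuple structure. Note that casting to uint8 only preserves label values in the range 0 to 255.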

Upvotes: 1

Views: 876

Answers (1)

bharys

Reputation: 192

Dataset.map can apply tf.cast to just the label component of each batch:

import numpy as np
import tensorflow as tf

a = tf.data.Dataset.from_tensor_slices(np.empty((2, 5, 3)))
b = tf.data.Dataset.range(5, 8)
c = tf.data.Dataset.zip((a, b))
d = c.batch(1)
print(d)
# <BatchDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int64)>

# change the dtype of the 2nd element of each batch from int64 to int8
e = d.map(lambda x, y: (x, tf.cast(y, tf.int8)))
print(e)
# <MapDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int8)>
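The same idea applies when a dataset yields dicts instead of tuples, which is one way to read "a field" in the question title. A minimal sketch, assuming a hypothetical dataset with keys "image" and "label" (illustrative names, not from the question): the mapped function copies the dict and replaces only the "label" entry.

```python
import tensorflow as tf

# Hypothetical dict-structured dataset; the keys "image" and "label"
# are illustrative assumptions.
ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros((3, 2, 2, 3), dtype=tf.float32),
    "label": tf.constant([0, 1, 2], dtype=tf.int64),
}).batch(2)


def cast_label(example):
    # Copy the dict so only the "label" entry changes dtype.
    out = dict(example)
    out["label"] = tf.cast(example["label"], tf.uint8)
    return out


ds_cast = ds.map(cast_label)
print(ds_cast.element_spec)
```

Since `map` runs after `batch` here, `tf.cast` is applied once per batch rather than once per example.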

Upvotes: 4
