Hold My Stack

Reputation: 55

Cannot change the dtype with tf.cast when the value passed is dataset.take(1)

I want to change the dtype of one element of my dataset. Each element has shape (32, 28, 28), i.e. one batch of 32 images of 28 by 28 pixels from the MNIST dataset.

So I ran the following command: tf.cast(dataset.take(1), tf.float32).

The type of my dataset is tensorflow.python.data.ops.dataset_ops.PrefetchDataset.

It threw an error: Attempt to convert a value (<TakeDataset shapes: (32, 28, 28), types: tf.uint8>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.TakeDataset'>) to a Tensor.

So I took one element from the dataset using this code:

    for batch_data in dataset:
        one_element = batch_data  # the first batch, a tf.Tensor
        break

Then I ran tf.cast(one_element, tf.float32), and it worked.

Why does this happen?

Upvotes: 0

Views: 419

Answers (1)

Andrey

Reputation: 6377

tf.data.Dataset.take() returns a Dataset, not a tensor, even when you call take(1): https://www.tensorflow.org/api_docs/python/tf/data/Dataset
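A minimal sketch of the difference, assuming TensorFlow 2.x in eager mode. The all-zeros uint8 dataset below is a hypothetical stand-in for the batched, prefetched MNIST dataset in the question. Because take(1) yields another Dataset, the batch has to be pulled out (here with next(iter(...))) before calling tf.cast. If the goal is to cast every batch, mapping the cast over the dataset with Dataset.map is one common option.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's pipeline: 64 uint8 "images"
# of 28x28, batched by 32 and prefetched (like a PrefetchDataset).
images = np.zeros((64, 28, 28), dtype=np.uint8)
dataset = tf.data.Dataset.from_tensor_slices(images).batch(32).prefetch(1)

# take(1) returns a new Dataset containing at most one batch, not a tensor.
taken = dataset.take(1)
print(isinstance(taken, tf.data.Dataset))  # True
print(isinstance(taken, tf.Tensor))        # False

# Casting the Dataset object itself fails, as in the question.
cast_error = None
try:
    tf.cast(taken, tf.float32)
except (TypeError, ValueError) as err:  # exact exception type may vary by TF version
    cast_error = err
    print("tf.cast on a Dataset failed:", type(err).__name__)

# Pull the single batch out of the dataset; this one is a tf.Tensor.
first_batch = next(iter(taken))
first_batch_f32 = tf.cast(first_batch, tf.float32)
print(first_batch_f32.dtype, first_batch_f32.shape)  # float32, shape (32, 28, 28)

# Alternatively, cast every batch lazily inside the input pipeline.
dataset_f32 = dataset.map(lambda batch: tf.cast(batch, tf.float32))
print(dataset_f32.element_spec.dtype)  # float32
```

next(iter(...)) is convenient for inspecting a single batch eagerly, while map keeps the conversion inside the tf.data pipeline so every batch arrives already cast.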

Upvotes: 1
