Reputation: 55
I want to change the dtype of one element of my dataset. (Element shape = (32, 28, 28); this is one batch of 28 by 28 images from the MNIST dataset.)
So I ran the following command: tf.cast(dataset.take(1), tf.float32)
The type of my dataset is tensorflow.python.data.ops.dataset_ops.PrefetchDataset.
It threw an error: Attempt to convert a value (<TakeDataset shapes: (32, 28, 28), types: tf.uint8>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.TakeDataset'>) to a Tensor.
So I took one element from the dataset using this code:
for batch_data in dataset:
    one_element = batch_data
    break
Then I ran tf.cast(one_element, tf.float32), and it works.
May I know why this is happening?
Upvotes: 0
Views: 419
Reputation: 6377
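tf.data.Dataset.take() returns a Dataset, not a tensor, even when you call take(1), so tf.cast() cannot convert it. Iterating over a dataset yields its elements as tensors, so the batch you get from the for loop is a tensor, which is why casting it works. See https://www.tensorflow.org/api_docs/python/tf/data/Dataset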
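A minimal sketch of both options, assuming TensorFlow 2.x with eager execution. The zero-filled `images` array is only a stand-in for the MNIST data, and the names `images`, `one_element`, and `float_dataset` are illustrative, not taken from the question's code.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the MNIST images (assumption): 64 uint8 images of 28x28.
images = np.zeros((64, 28, 28), dtype=np.uint8)
dataset = tf.data.Dataset.from_tensor_slices(images).batch(32).prefetch(1)

# take(1) is still a Dataset, so iterate over it to get the tensor inside.
for batch in dataset.take(1):
    one_element = tf.cast(batch, tf.float32)

# To cast every batch rather than one, map tf.cast over the dataset.
float_dataset = dataset.map(lambda x: tf.cast(x, tf.float32))
```

If the goal is to feed float images to a model, the `map` version is probably the more convenient one, since the cast then becomes part of the input pipeline.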
Upvotes: 1