Reputation: 36604
When I use .batch() as the last operation on a tf.data.Dataset, I can get the batch size like this:
train_ds._batch_size.numpy()
For instance, with this dataset:
import tensorflow as tf
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(8)
train_ds._batch_size.numpy()
8
However, when I use .batch().prefetch(1), I cannot get the batch size:
AttributeError: 'PrefetchDataset' object has no attribute '_batch_size'
Upvotes: 0
Views: 1133
Reputation: 36604
For a TensorFlow dataset built by chaining transformations, you can access the dataset that each transformation was applied to through the private ._input_dataset attribute:
train_ds._input_dataset
<BatchDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>
Now that you have accessed the BatchDataset object, you can get the batch size the same way:
train_ds._input_dataset._batch_size.numpy()
8
The same works through several transformations, e.g. .batch().prefetch().cache(); add one ._input_dataset for each transformation applied after .batch():
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.batch(8).prefetch(1).cache()
train_ds._input_dataset._input_dataset._batch_size.numpy()
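If you'd rather not count how many ._input_dataset hops separate the outer dataset from the BatchDataset, a small helper can walk the chain until it finds a _batch_size attribute. This is only a sketch: find_batch_size is a hypothetical name, it relies on the private _input_dataset and _batch_size attributes (TensorFlow internals that may change between versions), and it only follows single-input transformations, so pipelines built with zip() or concatenate() are not handled.

```python
def find_batch_size(dataset, max_depth=100):
    """Walk back through chained tf.data transformations via the private
    `_input_dataset` attribute and return the first `_batch_size` found.

    Hypothetical helper that relies on TensorFlow internals.
    Returns None if no batched dataset is found in the chain.
    """
    current = dataset
    for _ in range(max_depth):
        batch_size = getattr(current, "_batch_size", None)
        if batch_size is not None:
            # In eager mode `_batch_size` is a scalar tensor; convert it to int.
            if hasattr(batch_size, "numpy"):
                return int(batch_size.numpy())
            return int(batch_size)
        # Step back one transformation; source datasets have no `_input_dataset`.
        current = getattr(current, "_input_dataset", None)
        if current is None:
            return None
    return None
```

With the .batch(8).prefetch(1).cache() pipeline above, find_batch_size(train_ds) should return 8, and it returns None when no batch transformation is found in the chain.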
Upvotes: 3
Reputation: 601
If I understand you correctly, you want to get the batch size after calling the prefetch() method.
You can't get the batch size directly from the PrefetchDataset; however, you can get its buffer size, if that is what you want.
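According to the source code of the PrefetchDataset class, the prefetch() method returns a PrefetchDataset object, which stores the buffer size in its _buffer_size attribute (a tensor, not a method). So:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(8).prefetch(1)
train_ds._buffer_size.numpy()
will do the job.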
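Here is a minimal sketch that also covers pipelines where prefetch() is not the last step, assuming (as in TensorFlow 2.x eager mode) that the prefetch dataset keeps its buffer size in the private _buffer_size attribute as a scalar tensor. get_prefetch_buffer_size is a hypothetical helper. It matches on the class name because the exact name has varied across TensorFlow versions, and because shuffle() datasets also store a _buffer_size.

```python
def get_prefetch_buffer_size(dataset, max_depth=100):
    """Return the buffer size of the outermost prefetch transformation in a
    chain of tf.data datasets, or None if there is none.

    Hypothetical helper: relies on the private `_input_dataset` and
    `_buffer_size` attributes of TensorFlow's dataset classes.
    """
    current = dataset
    for _ in range(max_depth):
        if current is None:
            return None
        # Match on the class name rather than on `_buffer_size` alone,
        # since shuffle() datasets also have a `_buffer_size` attribute.
        if "Prefetch" in type(current).__name__:
            size = current._buffer_size  # an attribute, not a method
            # A value of -1 means the buffer size was tf.data.AUTOTUNE.
            if hasattr(size, "numpy"):
                return int(size.numpy())
            return int(size)
        # Step back one transformation; source datasets have no `_input_dataset`.
        current = getattr(current, "_input_dataset", None)
    return None
```

For tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(8).prefetch(1) this should return 1.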
Upvotes: 0