Nicolas Gervais

Reputation: 36604

How do I get the batch size of a TensorFlow Prefetch/Cache Dataset?

When I use .batch() as the last operation on a tf.data.Dataset, I can get the batch size like this:

train_ds._batch_size.numpy()

For instance with this dataset:

import tensorflow as tf

(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(8)

train_ds._batch_size.numpy()
8

However, when I use .batch().prefetch(1), I cannot get the batch size:

AttributeError: 'PrefetchDataset' object has no attribute '_batch_size'
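Below is a self-contained reproduction. It uses a small synthetic tensor in place of downloading Fashion-MNIST (the 28x28 shape mimics it; the 20-example size is arbitrary). It also shows that the public element_spec survives .prefetch(), but the batch dimension is static only when .batch() is called with drop_remainder=True. Otherwise it is None, because the last batch may be smaller.

```python
import tensorflow as tf

# Synthetic stand-in for Fashion-MNIST (assumption: 20 images of 28x28),
# so nothing has to be downloaded.
x = tf.zeros([20, 28, 28])
y = tf.zeros([20])

batched = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)
prefetched = batched.prefetch(1)

print(batched._batch_size.numpy())         # 8
print(hasattr(prefetched, "_batch_size"))  # False

# Public API: the batch dimension is unknown (None) because 20 = 8 + 8 + 4,
# so the last batch is smaller than the others.
print(prefetched.element_spec[0].shape[0])  # None

# With drop_remainder=True every batch has exactly 8 elements,
# so the static shape records the batch size.
fixed = tf.data.Dataset.from_tensor_slices((x, y)).batch(8, drop_remainder=True).prefetch(1)
print(fixed.element_spec[0].shape[0])  # 8
```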

Upvotes: 0

Views: 1133

Answers (2)

Nicolas Gervais

Reputation: 36604

When a dataset is produced by a transformation such as .prefetch() or .cache(), you can reach the dataset it was built from (the one before the last chained method) through the private attribute ._input_dataset:

train_ds._input_dataset
<BatchDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>

Now that you have accessed the BatchDataset object, you can get the batch size the same way:

train_ds._input_dataset._batch_size.numpy()
8

The same works through several chained transformations, e.g. .batch().prefetch().cache(). Follow ._input_dataset once for every transformation applied after .batch():

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

train_ds = train_ds.batch(8).prefetch(1).cache()
train_ds._input_dataset._input_dataset._batch_size.numpy()
8
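Rather than counting the transformations by hand, you can wrap the lookup in a small helper. It keeps following ._input_dataset until it reaches a dataset that has a ._batch_size attribute. This is only a sketch: find_batch_size is a hypothetical name, and it relies on private TensorFlow attributes that can change between versions. It assumes eager mode, so that .numpy() is available, and it does not descend into multi-input datasets such as zip() or concatenate().

```python
def find_batch_size(dataset):
    """Return the batch size of the closest .batch() step, or None.

    Walks back through the chain of transformations using the private
    ``_input_dataset`` attribute (hypothetical helper; relies on TF internals).
    """
    current = dataset
    while current is not None:
        batch_size = getattr(current, "_batch_size", None)
        if batch_size is not None:
            # In eager mode _batch_size is a scalar tf.int64 tensor;
            # .numpy() converts it to a NumPy integer.
            if hasattr(batch_size, "numpy"):
                batch_size = batch_size.numpy()
            return int(batch_size)
        # Source datasets (e.g. from_tensor_slices) have no _input_dataset,
        # so getattr returns None and the loop stops.
        current = getattr(current, "_input_dataset", None)
    return None
```

With the train_ds defined just above (batch(8).prefetch(1).cache()), find_batch_size(train_ds) should return 8. For a dataset without a .batch() step it returns None.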

Upvotes: 3

bluebird_lboro

Reputation: 601

If I understand you correctly, you want to get the batch size after you use the prefetch() method.

You can't get the batch size directly from the PrefetchDataset; however, you can get the buffer size, if that is what you want.

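The prefetch() method returns a PrefetchDataset object. According to the source code of the PrefetchDataset class, that object keeps the buffer size in the private attribute _buffer_size, stored as a scalar tensor (an attribute, not a method). So

tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(8).prefetch(1)._buffer_size.numpy()

will do the job (it returns 1 here).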
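One caveat is worth a quick check. _buffer_size only reflects the argument passed to prefetch(). If you pass tf.data.AUTOTUNE (tf.data.experimental.AUTOTUNE in older releases), which is the sentinel value -1, you read back -1 rather than the buffer size the runtime actually chooses. The sketch below assumes TensorFlow 2.4 or newer and uses a synthetic Dataset.range input:

```python
import tensorflow as tf

# Synthetic input (assumption): 100 integers in batches of 8.
ds = tf.data.Dataset.range(100).batch(8)

fixed = ds.prefetch(1)
auto = ds.prefetch(tf.data.AUTOTUNE)

# _buffer_size is a private scalar tf.int64 tensor, read with .numpy().
print(fixed._buffer_size.numpy())  # 1
# AUTOTUNE is the sentinel -1, so the autotuned dataset reports -1,
# not the buffer size picked at runtime.
print(auto._buffer_size.numpy())   # -1
```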

Upvotes: 0
