kawingkelvin

Reputation: 3951

Is there a way to find the batch size for a tf.data.Dataset?

I understand that you can assign a batch size to a Dataset, which returns a new dataset object. Is there an API to query the batch size of a given dataset object?

I have been looking for the relevant calls at:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset

Upvotes: 3

Views: 7264

Answers (3)

Vlad

Reputation: 8585

In TensorFlow 1.x, access the batch size via the private attribute dataset._dataset._batch_size:

import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

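# In TF 1.x, .batch() returns a wrapper object; its private _dataset attribute
# holds the underlying batched dataset, whose _batch_size is a tensor that
# must be evaluated in a session.
with tf.compat.v1.Session() as sess:
    batch_size = sess.run(dataset._dataset._batch_size)
    print(batch_size) # 10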

In TensorFlow 2, you can access it directly via dataset._batch_size:

import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1

dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)

# _batch_size is a private attribute holding an int64 tensor
batch_size = dataset._batch_size.numpy()

print(batch_size) # 10
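For comparison, here is a sketch that avoids private attributes, assuming TensorFlow 2.x, where every dataset exposes a public element_spec property describing the static shape of its elements. The batch dimension is only known statically when .batch() is called with drop_remainder=True; otherwise it is None, because the final batch may be smaller. The names ds_default and ds_fixed are illustrative.

```python
import numpy as np
import tensorflow as tf

data = np.random.randint(0, 2, 100)

# Without drop_remainder the last batch may be partial, so the static
# batch dimension is unknown (None).
ds_default = tf.data.Dataset.from_tensor_slices(data).batch(10)
print(ds_default.element_spec.shape[0])  # None

# With drop_remainder=True every batch has exactly 10 elements, so the
# batch dimension becomes part of the static shape.
ds_fixed = tf.data.Dataset.from_tensor_slices(data).batch(10, drop_remainder=True)
print(ds_fixed.element_spec.shape[0])  # 10
```

The trade-off is that drop_remainder=True silently discards a final partial batch; with 100 elements and a batch size of 10, nothing is dropped here.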

Upvotes: 1

Tolik

Reputation: 435

When you call the .batch(32) method, it returns a tensorflow.python.data.ops.dataset_ops.BatchDataset object. This kind of object has a private attribute called ._batch_size, which contains the batch size as a tensor.

In TensorFlow 2.x you just need to call the .numpy() method of this tensor to convert it to a numpy.int64 value. In TensorFlow 1.x you need to call its .eval() method instead, which requires a default session.
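A small helper can wrap this idea defensively. This is a sketch assuming TensorFlow 2.x (eager execution, so .numpy() works); the function name private_batch_size is hypothetical. Because _batch_size lives only on the batching dataset itself, applying another transformation such as .map() after .batch() yields an object without that attribute, so the helper returns None instead of raising AttributeError.

```python
import tensorflow as tf


def private_batch_size(dataset):
    """Hypothetical helper: read the private _batch_size attribute, or None."""
    # _batch_size is private and undocumented; it may change between releases.
    size = getattr(dataset, "_batch_size", None)
    if size is None:
        # Not a batching dataset, e.g. never batched, or transformed after .batch().
        return None
    # In TF 2.x the attribute is an int64 eager tensor; convert to a Python int.
    return int(size.numpy())


batched = tf.data.Dataset.range(100).batch(32)
mapped = batched.map(lambda x: x * 2)

print(private_batch_size(batched))  # 32
print(private_batch_size(mapped))   # None
```

Since _batch_size is private, its name and location can change between TensorFlow releases; treat this as a debugging aid rather than something to rely on in production code.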

Upvotes: 3

Bashir Kazimi

Reputation: 1377

I do not know whether you can get it as an attribute, but you can iterate through the dataset once and print the shape of the first batch:

import tensorflow as tf

# create a simple tf.data.Dataset with batch size 3
f = tf.data.Dataset.range(10).batch(3)

# iterating once
for one_batch in f:
    print('batch size:', one_batch.shape[0])
    break

If you know your dataset has targets/labels as well, you have to iterate as follows:

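# a dataset of (features, labels) pairs, batched with batch size 3
f = tf.data.Dataset.from_tensor_slices((tf.range(10), tf.range(10))).batch(3)

# iterating once
for one_batch_x, one_batch_y in f:
    print('batch size:', one_batch_x.shape[0])
    break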

In both cases, it will print:

batch size: 3
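Both loops can be folded into one helper. This is a sketch assuming TensorFlow 2.x eager execution; first_batch_size is a hypothetical name, and tf.nest.flatten is used so the same code handles a plain tensor, a (features, labels) tuple or a dict of tensors.

```python
import tensorflow as tf


def first_batch_size(dataset):
    """Return the leading dimension of the first batch yielded by dataset."""
    # Pull a single batch from a fresh iterator (requires eager execution).
    first_batch = next(iter(dataset))
    # Flatten tuples/dicts of tensors into a list and take the first tensor.
    first_tensor = tf.nest.flatten(first_batch)[0]
    return first_tensor.shape[0]


plain = tf.data.Dataset.range(10).batch(3)
labeled = tf.data.Dataset.from_tensor_slices((tf.range(10), tf.range(10))).batch(3)

print(first_batch_size(plain))    # 3
print(first_batch_size(labeled))  # 3
```

Keep in mind that this reports the size of the first batch, which equals the configured batch size only if the dataset holds at least one full batch (range(2).batch(3), for instance, yields a single batch of 2). It also creates an iterator and pulls one batch, which can be slow for expensive input pipelines.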

Upvotes: 1
