Reputation: 3951
I understand you can assign a batch size to a Dataset and return a new dataset object. Is there an API to interrogate the batch size given a dataset object?
I am trying to find the calls at:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset
Upvotes: 3
Views: 7264
Reputation: 8585
In TensorFlow 1.x, access the batch size via dataset._dataset._batch_size:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 1.14.0
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
with tf.compat.v1.Session() as sess:
    batch_size = sess.run(dataset._dataset._batch_size)
    print(batch_size) # 10
In TensorFlow 2, you can access it via dataset._batch_size:
import tensorflow as tf
import numpy as np
print(tf.__version__) # 2.0.1
dataset = tf.data.Dataset.from_tensor_slices(np.random.randint(0, 2, 100)).batch(10)
batch_size = dataset._batch_size.numpy()
print(batch_size) # 10
Upvotes: 1
Reputation: 435
When you call the .batch(32) method, it returns a tensorflow.python.data.ops.dataset_ops.BatchDataset object. As documented in the TensorFlow documentation, this kind of object has a private attribute called ._batch_size, which contains a tensor holding the batch size.
In TensorFlow 2.x you just need to call the .numpy() method of this tensor to convert it to a numpy.int64 value.
In TensorFlow 1.x you need to call the .eval() method (inside an active session).
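Since ._batch_size is private and may change between releases, a public-API alternative is sometimes possible. The sketch below assumes TensorFlow 2.x and reads the static leading dimension from dataset.element_spec. The helper name static_batch_size is hypothetical, and the approach only yields a number when the dataset was batched with drop_remainder=True; otherwise the leading dimension is unknown and the helper returns None.

```python
import tensorflow as tf

def static_batch_size(dataset):
    """Hypothetical helper: return the static leading dimension of the
    dataset's first component, or None if it is not statically known."""
    # element_spec may be a single spec or a nested structure (tuple, dict, ...)
    first_spec = tf.nest.flatten(dataset.element_spec)[0]
    shape = first_spec.shape
    if shape.rank is None or shape.rank == 0:
        return None  # unknown rank, or scalar elements (dataset not batched)
    return shape[0]  # an int, or None when the dimension is dynamic

fixed = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
dynamic = tf.data.Dataset.range(10).batch(3)

print(static_batch_size(fixed))    # 3
print(static_batch_size(dynamic))  # None
```

Without drop_remainder=True the last batch may be smaller, so TensorFlow leaves the batch dimension as None in the spec.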
Upvotes: 3
Reputation: 1377
I do not know if you can just get it as an attribute, but you could just iterate through the dataset once and print the shape:
# create a simple tf.data.Dataset with batchsize 3
import tensorflow as tf
f = tf.data.Dataset.range(10).batch(3) # Dataset with batch_size 3
# iterating once
for one_batch in f:
    print('batch size:', one_batch.shape[0])
    break
If your dataset also yields targets/labels, i.e. each element is an (x, y) pair, you have to unpack them while iterating:
# iterating once
for one_batch_x, one_batch_y in f:
    print('batch size:', one_batch_x.shape[0])
    break
In both cases, it will print:
batch size: 3
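The two loops above can be folded into one function that works whether or not the dataset has labels. This is only a sketch, assuming TensorFlow 2.x eager execution and a dataset that has already been batched; the name first_batch_size is made up for illustration.

```python
import tensorflow as tf

def first_batch_size(dataset):
    """Hypothetical helper: leading dimension of the first batch, or None
    if the dataset is empty. Assumes the dataset has been batched."""
    for element in dataset.take(1):
        # flatten handles a single tensor, an (x, y) tuple, or a dict of features
        first_component = tf.nest.flatten(element)[0]
        return int(first_component.shape[0])
    return None  # the dataset yielded no elements

unlabeled = tf.data.Dataset.range(10).batch(3)
labeled = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([10, 2]), tf.zeros([10]))
).batch(4)

print(first_batch_size(unlabeled))  # 3
print(first_batch_size(labeled))    # 4
```

Note that this reports the size of the first batch only. If the dataset holds fewer elements than the configured batch size, the value will be smaller than what was passed to .batch().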
Upvotes: 1