Reputation: 24581
I was wondering how to enforce batches with a fixed number of samples when using the Dataset API.
For example,
import numpy as np
import tensorflow as tf
dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess = tf.InteractiveSession()
try:
    while True:
        print(batch.eval().shape)
except tf.errors.OutOfRangeError:
    pass
In this toy example, the data has 101 samples in total and I ask for batches of 10 samples. When iterating, the last batch has a size of 1, which is what I want to avoid.
In the former (queue-based) API, tf.train.batch has an allow_smaller_final_batch argument that is set to False by default. I want to reproduce this behavior with Dataset.
I suppose I could use Dataset.filter:
dataset = tf.data.Dataset.range(101).batch(10).filter(
    lambda x: tf.equal(tf.shape(x)[0], 10))
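This does filter out the short batch when run through the same loop as above (a sketch under the same TF 1.x session setup):
import tensorflow as tf
dataset = tf.data.Dataset.range(101).batch(10).filter(
    lambda x: tf.equal(tf.shape(x)[0], 10))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess = tf.InteractiveSession()
try:
    while True:
        print(batch.eval().shape)  # prints (10,) ten times; the (1,) batch is filtered out
except tf.errors.OutOfRangeError:
    pass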
But surely there should be some built-in way to do this?
Upvotes: 5
Views: 4282
Reputation: 168
For tensorflow>=2.0.0, you can use the drop_remainder argument of the batch method of tf.data.Dataset:
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
The drop_remainder argument controls whether the last batch is dropped when it has fewer than BATCH_SIZE elements. The default value is False.
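For instance, a minimal sketch of the question's toy example in eager mode (assuming the goal is simply to discard the final 1-element batch):
import tensorflow as tf
# 101 samples batched in tens; drop_remainder discards the final short batch
dataset = tf.data.Dataset.range(101).batch(10, drop_remainder=True)
for batch in dataset:
    print(batch.shape)  # prints (10,) ten times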
I hope this helps readers in 2019+
Upvotes: 7
Reputation: 126154
You can use tf.contrib.data.batch_and_drop_remainder(batch_size) to do this:
dataset = tf.data.Dataset.range(101).apply(
tf.contrib.data.batch_and_drop_remainder(10))
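For reference, a sketch of iterating the result with the question's session-based loop (same TF 1.x setup as above):
import tensorflow as tf
dataset = tf.data.Dataset.range(101).apply(
    tf.contrib.data.batch_and_drop_remainder(10))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess = tf.InteractiveSession()
try:
    while True:
        print(batch.eval().shape)  # always (10,); the final batch of 1 is dropped
except tf.errors.OutOfRangeError:
    pass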
Upvotes: 3