P-Gn

Reputation: 24581

Fixed-size batches (potentially discarding the last batch) using Dataset

I was wondering how to enforce the use of batches with a fixed number of samples when using Dataset.

For example,

import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.range(101).batch(10)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

sess = tf.InteractiveSession()
try:
  while True:
    print(batch.eval().shape)
except tf.errors.OutOfRangeError:
  pass

In this toy example, the data has a total of 101 samples and I ask for batches of 10 samples. When iterating, the last batch has a size of 1, which is what I want to avoid.

In the former, queue-based API, tf.train.batch has an allow_smaller_final_batch argument that defaults to False. I want to reproduce this behavior with Dataset.

I suppose I could use Dataset.filter:

dataset = tf.data.Dataset.range(101).batch(10).filter(
    lambda x: tf.equal(tf.shape(x)[0], 10))

but surely there should be some built-in way to do this?
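
(For reference, a minimal runnable sketch of that filter workaround, combining the snippet with the TF 1.x loop from the toy example above; it should print (10,) exactly ten times and silently drop the final batch of 1:)

import tensorflow as tf

# Keep only batches whose leading dimension is exactly 10.
dataset = tf.data.Dataset.range(101).batch(10).filter(
    lambda x: tf.equal(tf.shape(x)[0], 10))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

sess = tf.InteractiveSession()
try:
  while True:
    print(batch.eval().shape)
except tf.errors.OutOfRangeError:
  pass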

Upvotes: 5

Views: 4282

Answers (2)

Aakash Patil

Reputation: 168

For TensorFlow >= 2.0.0, you can use the drop_remainder argument of the batch method of tf.data.Dataset:

dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

The drop_remainder argument controls whether the last batch is dropped when it has fewer than BATCH_SIZE elements; it defaults to False.
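
Applied to the toy example from the question (TensorFlow 2.x, eager execution), a minimal sketch:

import tensorflow as tf

# 101 samples in batches of 10; the final partial batch of 1 is dropped.
dataset = tf.data.Dataset.range(101).batch(10, drop_remainder=True)
for batch in dataset:
    print(batch.shape)  # prints (10,) exactly ten times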

I hope this helps readers in 2019+

Upvotes: 7

mrry

Reputation: 126154

You can use tf.contrib.data.batch_and_drop_remainder(batch_size) to do this:

dataset = tf.data.Dataset.range(101).apply(
    tf.contrib.data.batch_and_drop_remainder(10))
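
Plugged into the question's loop, a minimal sketch (assuming a TensorFlow 1.x version that still ships tf.contrib):

import tensorflow as tf

# apply() takes a whole-dataset transformation; partial batches are dropped.
dataset = tf.data.Dataset.range(101).apply(
    tf.contrib.data.batch_and_drop_remainder(10))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

sess = tf.InteractiveSession()
try:
  while True:
    print(batch.eval().shape)  # always (10,)
except tf.errors.OutOfRangeError:
  pass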

Upvotes: 3
