mining

Reputation: 3709

How to pad to fixed BATCH_SIZE in tf.data.Dataset?

I have a dataset with 11 samples, and when I set BATCH_SIZE to 2, the following code fails:

import tensorflow as tf

dataset = tf.contrib.data.TFRecordDataset(filenames)  # read serialized records
dataset = dataset.map(parser)                         # decode each record
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)                   # last batch may be smaller
dataset = dataset.repeat(count=1)

The problem lies in dataset = dataset.batch(batch_size): when the dataset reaches the last batch, only 1 sample remains. Is there any way to randomly pick one of the previously visited samples to fill out the last batch?
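
(For context, a rough sketch of the padding idea the question asks for, assuming the tf.data API and that the number of samples is known up front: pad the dataset to a multiple of batch_size with elements drawn from a shuffled copy of itself.)

import tensorflow as tf

num_samples = 11
batch_size = 2
pad = (batch_size - num_samples % batch_size) % batch_size  # 1 extra element needed

dataset = tf.data.Dataset.range(num_samples)
if pad:
    # Reuse `pad` randomly chosen elements from the dataset as filler.
    filler = tf.data.Dataset.range(num_samples).shuffle(num_samples).take(pad)
    dataset = dataset.concatenate(filler)
batched = dataset.batch(batch_size)  # now 6 full batches of 2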

Upvotes: 4

Views: 3495

Answers (2)

Brian

Reputation: 7326

You can just set drop_remainder=True in your call to batch.

dataset = dataset.batch(batch_size, drop_remainder=True)

From the documentation:

drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
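
As a quick check (a sketch assuming TensorFlow 2.x, where datasets are directly iterable in eager mode), 11 elements batched by 2 with drop_remainder=True yield exactly 5 full batches:

import tensorflow as tf

dataset = tf.data.Dataset.range(11).batch(2, drop_remainder=True)
for batch in dataset:
    print(batch.numpy())  # [0 1], [2 3], [4 5], [6 7], [8 9]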

Upvotes: 3

Olivier Moindrot

Reputation: 28218

@mining proposes a solution by padding the filenames.

Another solution is to use tf.contrib.data.batch_and_drop_remainder. This will batch the data with a fixed batch size and drop the last smaller batch.

In your example, with 11 inputs and a batch size of 2, this would yield 5 batches of 2 elements.

Here is the example from the documentation:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
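
To inspect the result (a sketch for the TF 1.x graph mode that tf.contrib belongs to, using a one-shot iterator to pull the batches):

iterator = batched.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    for _ in range(5):
        print(sess.run(next_batch))  # [0 1], [2 3], [4 5], [6 7], [8 9]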

Upvotes: 7
