Wei

Reputation: 341

Tensorflow dataset with partial shuffle

I am experimenting with TensorFlow's Dataset API and am confused by the shuffle() method. According to the docs:

The Dataset.shuffle() transformation randomly shuffles the input dataset using a similar algorithm to tf.RandomShuffleQueue: it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
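
The buffer algorithm described in that quote can be modelled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's actual implementation; the function name `buffer_shuffle` and the use of Python's `random.Random` are assumptions made for illustration.

```python
import random
from typing import Iterable, Iterator, List

_END = object()  # sentinel marking an exhausted input

def buffer_shuffle(items: Iterable, buffer_size: int, seed: int = 0) -> Iterator:
    """Hypothetical model of a fixed-size shuffle buffer (not TensorFlow's code)."""
    rng = random.Random(seed)
    source = iter(items)
    # Fill the buffer with the first buffer_size input elements.
    buffer: List = []
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    while buffer:
        # Choose the next output uniformly at random from the buffer.
        idx = rng.randrange(len(buffer))
        chosen = buffer[idx]
        nxt = next(source, _END)
        if nxt is _END:
            buffer.pop(idx)  # input exhausted: just drain the buffer
        else:
            buffer[idx] = nxt  # refill the freed slot immediately
        yield chosen
```

In this model, `list(buffer_shuffle(range(1, 9), buffer_size=4))` always starts with one of 1 to 4, but its second element can already be 5, because the slot freed by the first pick is refilled before the second pick.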

If I only 'partially' shuffle my dataset (i.e. buffer_size < number of elements), I'd expect only the first buffer_size elements to be shuffled. However, this is not the case; see this example:

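import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_initializable_iterator)

# Parentheses let the chained calls span several lines without a SyntaxError.
dataset = (tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8])
           .shuffle(buffer_size=4, seed=42)
           .batch(2))
iterator = dataset.make_initializable_iterator()  # create the iterator (avoid shadowing built-in iter)
el = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print('batch:', sess.run(el))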

Output:

batch: [2 5]

Why is 5 in the first batch when the buffer size is only 4? Shouldn't the first 2 elements both come from 1 to 4? What am I missing here?

Thanks

Upvotes: 1

Views: 333

Answers (1)

iga

Reputation: 3633

The short answer is that the shuffle buffer can be replenished at any time, including in the middle of creating a batch, because batch() pulls its elements from the shuffled dataset one at a time.
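
To see which first batches are possible at all, one can enumerate every choice a fill-then-refill buffer could make. This is a minimal sketch under the assumption that the buffer starts with the first buffer_size elements and refills a slot as soon as it is emptied; `possible_prefixes` is a hypothetical helper, not part of TensorFlow. Since batch(2) just takes the first two outputs of the shuffle, the first batch is the first two emitted elements.

```python
from typing import List, Set, Tuple

def possible_prefixes(data: List[int], buffer_size: int, k: int) -> Set[Tuple[int, ...]]:
    """All ordered sequences of the first k outputs a refilling shuffle
    buffer could emit (hypothetical helper, simplified model)."""
    results: Set[Tuple[int, ...]] = set()

    def step(buffer: List[int], next_pos: int, emitted: Tuple[int, ...]) -> None:
        if len(emitted) == k or not buffer:
            results.add(emitted)
            return
        # Try every element the buffer could pick next.
        for idx in range(len(buffer)):
            new_buffer = list(buffer)
            chosen = new_buffer[idx]
            if next_pos < len(data):
                new_buffer[idx] = data[next_pos]  # refill right away
                step(new_buffer, next_pos + 1, emitted + (chosen,))
            else:
                new_buffer.pop(idx)  # input exhausted
                step(new_buffer, next_pos, emitted + (chosen,))

    step(list(data[:buffer_size]), buffer_size, ())
    return results

first_batches = possible_prefixes([1, 2, 3, 4, 5, 6, 7, 8], buffer_size=4, k=2)
print((2, 5) in first_batches)                              # True
print(sorted({v for b in first_batches for v in b}))        # [1, 2, 3, 4, 5]
```

Under this model there are 16 possible first batches; 5 can be the second element but never the first.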

Here is how your observation could have happened:

  • The dataset reads the first 4 elements from your data. The shuffle buffer now contains [1, 2, 3, 4].
  • You request two elements (via get_next() on a dataset that creates batches of 2)
  • The shuffle dataset picks 2 and reads the next element into the shuffle buffer, which now contains [1, 3, 4, 5].
  • The shuffle dataset picks 5 from the buffer.
  • Your batch of [2, 5] is returned.
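
The steps above can be replayed deterministically by scripting the two picks that TensorFlow would normally make at random. This is an illustrative sketch; `replay_first_batch` and its `picks` argument are hypothetical and only model the buffer behaviour described in the list.

```python
from typing import List

def replay_first_batch(data: List[int], buffer_size: int, picks: List[int]) -> List[int]:
    """Replay a shuffle buffer with scripted picks (hypothetical helper).

    picks[i] is the value chosen on step i; TensorFlow chooses randomly.
    """
    buffer = list(data[:buffer_size])  # step 1: fill the buffer
    next_pos = buffer_size
    print("initial buffer:", buffer)
    batch = []
    for value in picks:
        idx = buffer.index(value)  # raises ValueError if value is not buffered
        batch.append(buffer[idx])
        if next_pos < len(data):
            buffer[idx] = data[next_pos]  # refill the freed slot
            next_pos += 1
        else:
            buffer.pop(idx)
        print(f"picked {value}, buffer now {sorted(buffer)}")
    return batch

print(replay_first_batch([1, 2, 3, 4, 5, 6, 7, 8], 4, [2, 5]))
# initial buffer: [1, 2, 3, 4]
# picked 2, buffer now [1, 3, 4, 5]
# picked 5, buffer now [1, 3, 4, 6]
# [2, 5]
```

Scripting the picks as [5, 2] instead raises ValueError, since 5 is not yet in the buffer on the first step.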

Upvotes: 4
