Reputation: 341
I am playing with TensorFlow's dataset API, and I am confused by the shuffle() method. According to the docs:
The Dataset.shuffle() transformation randomly shuffles the input dataset using a similar algorithm to tf.RandomShuffleQueue: it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
If I only 'partially' shuffle my dataset (i.e. buffer_size < no. of elements), I'd expect only the first buffer_size elements to be shuffled. However, this is not the case; see this example:
import tensorflow as tf  # TensorFlow 1.x API

# Parentheses let the chained calls span multiple lines.
dataset = (tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8])
           .shuffle(buffer_size=4, seed=42)
           .batch(2))
iterator = dataset.make_initializable_iterator()  # create the iterator
el = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print('batch:', sess.run(el))
output:
batch: [2 5]
Why is 5 here? The buffer size is only 4, so shouldn't the first 2 elements come from 1~4? What am I missing here?
Thanks
Upvotes: 1
Views: 333
Reputation: 3633
The short answer is that the shuffle buffer is replenished as soon as an element is drawn from it, including in the middle of creating a batch.
Here is how your observation could have happened:
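The mechanism can be illustrated with a minimal pure-Python sketch. This is not TensorFlow's implementation, and its random number generator differs, so it will not reproduce the exact [2 5] batch for seed=42. The function name shuffle_with_buffer is hypothetical, and it assumes buffer_size is at least 1. It only shows that once the first element leaves the buffer, element 5 takes its slot and becomes eligible to be the second element of the first batch.

```python
import random


def shuffle_with_buffer(items, buffer_size, rng):
    """Yield items in the order a fixed-size shuffle buffer would emit them.

    Hypothetical helper for illustration; assumes buffer_size >= 1.
    """
    source = iter(items)
    buffer = []
    # Fill the buffer with the first buffer_size elements.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit a uniformly chosen element, then refill the vacated slot
    # with the next input element before the next draw.
    for item in source:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# Print the first "batch" of two for a few seeds.
for seed in range(5):
    order = list(shuffle_with_buffer(range(1, 9), buffer_size=4,
                                     rng=random.Random(seed)))
    print('seed', seed, 'first batch:', order[:2], 'full order:', order)
```

In this sketch the first element always comes from 1~4, but the second is drawn from a buffer that already contains 5, so a first batch such as [2 5] is possible. More generally, the k-th element emitted (counting from 1) can be any of the first buffer_size + k - 1 inputs.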
Upvotes: 4