Reputation: 3487
I'm trying to understand how tf.data.Dataset works.
The documentation says that take returns a dataset containing at most the requested number of elements from the source dataset. You can then iterate over the result, which here is a single batch:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)
# Build your input pipeline
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)
single_batch_dataset = ds.take(1)
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
    # ...
Outputs:
tf.Tensor([2 0 6 6 8 8 6 0 3 4 8 7 5 2 5 7 8 7 1 1 1 8 6 4 0 4 3 2 4 2 1 9], shape=(32,), dtype=int64)
However, iterating over it again gives different labels (continuing from the code above):
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
Outputs:
tf.Tensor([7 3 5 6 3 1 7 9 6 1 9 3 9 8 6 7 7 1 9 7 5 2 0 7 8 1 7 8 7 0 5 0], shape=(32,), dtype=int64)
tf.Tensor([1 3 6 1 8 8 0 4 1 3 2 9 5 3 8 7 4 2 1 8 1 0 8 5 4 5 6 7 3 4 4 1], shape=(32,), dtype=int64)
Shouldn't the labels be the same, given that the dataset is the same?
Upvotes: 1
Views: 1209
Reputation: 19310
This happens because the data files are shuffled (shuffle_files=True) and the dataset is shuffled again with dataset.shuffle().
By default, dataset.shuffle() reshuffles the data in a different order on each iteration.
To prevent reshuffling between iterations, remove shuffle_files=True and pass reshuffle_each_iteration=False to dataset.shuffle().
The .take() function does not imply determinism. It just takes the first N items from the dataset in whichever order the dataset yields them.
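To see that take() only passes along whatever arrives from upstream, here is a minimal sketch on a toy dataset. It assumes TensorFlow 2.x with eager execution; tf.data.Dataset.range stands in for MNIST, and the helper as_list is a hypothetical name introduced only for this illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def as_list(dataset):
    """Hypothetical helper: iterate over a dataset once, collecting Python ints."""
    return [int(x) for x in dataset.as_numpy_iterator()]


# Toy dataset of the integers 0..9, standing in for real data.
numbers = tf.data.Dataset.range(10)

# No shuffle upstream: take(3) yields the same first three elements on every pass.
plain = numbers.take(3)
plain_runs = [as_list(plain), as_list(plain)]
print(plain_runs)  # [[0, 1, 2], [0, 1, 2]]

# Shuffle upstream of take: each pass draws from a freshly shuffled order,
# so the three elements can differ between passes.
shuffled_first = numbers.shuffle(10).take(3)
shuffled_runs = [as_list(shuffled_first), as_list(shuffled_first)]
print(shuffled_runs)

# take before shuffle: always the elements 0, 1, 2, but their order can vary.
take_first = numbers.take(3).shuffle(3)
take_first_runs = [as_list(take_first), as_list(take_first)]
print(take_first_runs)
```

In the question's pipeline the shuffle sits upstream of take(1), which corresponds to the second case here.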
# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=False)
# Build your input pipeline
ds = ds.shuffle(1024, reshuffle_each_iteration=False).batch(32).prefetch(tf.data.experimental.AUTOTUNE)
single_batch_dataset = ds.take(1)
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
Output:
tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
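The same behaviour can be reproduced without downloading MNIST. The following is a small sketch, assuming TensorFlow 2.x with eager execution; the toy tf.data.Dataset.range input and the helper two_passes are assumptions made for illustration and are not part of the original pipeline.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def two_passes(dataset):
    """Hypothetical helper: iterate over the same dataset object twice."""
    first = [int(x) for x in dataset.as_numpy_iterator()]
    second = [int(x) for x in dataset.as_numpy_iterator()]
    return first, second


# Toy dataset of the integers 0..99, standing in for MNIST examples.
numbers = tf.data.Dataset.range(100)

# Default behaviour (reshuffle_each_iteration=True): a new order on each pass,
# so the two lists will usually differ.
default_first, default_second = two_passes(numbers.shuffle(100).take(5))
print(default_first, default_second)

# reshuffle_each_iteration=False: the order is chosen once and reused,
# so both passes over this dataset object see the same five elements.
fixed_first, fixed_second = two_passes(
    numbers.shuffle(100, reshuffle_each_iteration=False).take(5)
)
print(fixed_first, fixed_second)
```

Note that no seed is passed here, so the fixed order is only stable across iterations within one run; a different run of the program may pick a different order.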
Upvotes: 2