Bersan

Reputation: 3487

Why do iterations over the same tf.data.Dataset give different data each time?

I'm trying to understand how tf.data.Dataset works.

The documentation says that take returns a dataset containing at most a given number of elements from this dataset. You can then iterate over a single sample (in this case, a batch):

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)

# Build your input pipeline
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)
# ...

Outputs:

tf.Tensor([2 0 6 6 8 8 6 0 3 4 8 7 5 2 5 7 8 7 1 1 1 8 6 4 0 4 3 2 4 2 1 9], shape=(32,), dtype=int64)

However, iterating over it again gives different labels (continuing from the code above):

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

Outputs:

tf.Tensor([7 3 5 6 3 1 7 9 6 1 9 3 9 8 6 7 7 1 9 7 5 2 0 7 8 1 7 8 7 0 5 0], shape=(32,), dtype=int64)
tf.Tensor([1 3 6 1 8 8 0 4 1 3 2 9 5 3 8 7 4 2 1 8 1 0 8 5 4 5 6 7 3 4 4 1], shape=(32,), dtype=int64)

Shouldn't the labels be the same, given that the dataset is the same?

Upvotes: 1

Views: 1209

Answers (1)

jkr

Reputation: 19310

This is because the data files are shuffled (shuffle_files=True) and the examples are shuffled again with dataset.shuffle().

By default, dataset.shuffle() reshuffles the data on each iteration, so every pass over the dataset produces a different order.

To prevent reshuffling between iterations, remove shuffle_files=True (or set it to False) and pass reshuffle_each_iteration=False to shuffle().
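A minimal way to see the difference without downloading MNIST is a small tf.data.Dataset.range pipeline (an illustrative stand-in for the real data, not part of the original question). Each pass is always a permutation of the same ten elements; only the order changes:

```python
import tensorflow as tf

# Ten elements, shuffled with a buffer large enough to hold all of them.
base = tf.data.Dataset.range(10)

# Default (reshuffle_each_iteration=None behaves like True):
# every new pass over the dataset draws a fresh permutation.
reshuffled = base.shuffle(10)

# reshuffle_each_iteration=False: the permutation is reused on every
# pass over this same dataset object.
fixed = base.shuffle(10, reshuffle_each_iteration=False)

# Iterate each dataset twice and collect the elements as Python ints.
reshuffled_passes = [[int(x) for x in reshuffled.as_numpy_iterator()] for _ in range(2)]
fixed_passes = [[int(x) for x in fixed.as_numpy_iterator()] for _ in range(2)]

print(reshuffled_passes)  # usually two different orders
print(fixed_passes)       # the same order twice
```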

The .take() method does not imply determinism: it simply takes the first N items in whatever order the upstream dataset produces them.
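As a rough sketch of that point, again using Dataset.range as a hypothetical stand-in for the MNIST pipeline: take(3) on an unshuffled dataset always yields the same three elements, while on a shuffled dataset it yields whatever the shuffle happens to emit first on that pass:

```python
import tensorflow as tf

ordered = tf.data.Dataset.range(100)
shuffled = ordered.shuffle(100)  # reshuffles on every pass by default

# take(3) just forwards the first three elements its input produces.
first_three_ordered = [[int(x) for x in ordered.take(3).as_numpy_iterator()] for _ in range(2)]
first_three_shuffled = [[int(x) for x in shuffled.take(3).as_numpy_iterator()] for _ in range(2)]

print(first_three_ordered)   # [[0, 1, 2], [0, 1, 2]]
print(first_three_shuffled)  # usually two different triples
```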

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Construct a tf.data.Dataset without shuffling the input files
ds = tfds.load('mnist', split='train', shuffle_files=False)

# Build your input pipeline; the shuffle order is now fixed across iterations
ds = ds.shuffle(1024, reshuffle_each_iteration=False).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

# Both loops print the same batch of labels
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)

for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)

Output:

tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
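If the goal is only to reuse one fixed batch while keeping the shuffled pipeline as it is, another option is to cache the taken batch with Dataset.cache(), which replays the same elements on every pass once the first pass has completed. This is a sketch with Dataset.range as a hypothetical stand-in for MNIST:

```python
import tensorflow as tf

# Stand-in pipeline: shuffled (reshuffled on every pass by default) and batched.
ds = tf.data.Dataset.range(1000).shuffle(1000).batch(32)

# cache() after take(1): the first complete pass stores the batch in memory,
# and later passes read the stored batch instead of re-running the shuffle.
single_batch_dataset = ds.take(1).cache()

# Iterate twice and compare the batches as plain Python lists.
passes = [[batch.numpy().tolist() for batch in single_batch_dataset] for _ in range(2)]
print(passes[0] == passes[1])  # True
```

Note that cache() only fixes what comes out of the cached dataset; the underlying ds still reshuffles on each pass.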

Upvotes: 2
