Reputation: 1795
I'm currently learning TensorFlow, but I'm confused by the code snippet below:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
I know that the dataset holds all the data, but what do shuffle(), repeat(), and batch() do to it?
Please help me with an example and explanation.
Upvotes: 102
Views: 79361
Reputation: 17794
batch() combines consecutive elements of the dataset into groups (batches):
without batching
dataset = tf.data.Dataset.range(10)
for i in dataset:
    print(i.numpy())
Output:
0
1
2
3
4
5
6
7
8
9
with batching
dataset = tf.data.Dataset.range(10)
for i in dataset.batch(2):
    print(i.numpy())
Output:
[0 1]
[2 3]
[4 5]
[6 7]
[8 9]
shuffle() randomly shuffles the input data. According to the docs, the Dataset.shuffle() transformation maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
At first I could not understand the result of chaining shuffle and batch when the buffer holds only part of the data: why do we get values greater than 19 in the first batch when the shuffle buffer only contains the first 20 values?
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(20).batch(10)
print(next(iter(dataset)).numpy())
Output:
[ 6 3 13 18 20 21 5 0 2 15]
It turns out that after batch has fetched a single value from the buffer, the next input value (20, 21, 22, ...) moves into the buffer, and batch can select it as one of its next values. In this way, the first batch contains 10 values drawn from roughly the first 30 elements of the dataset.
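Here is a small sanity check of that claim (a sketch assuming TF 2.x eager execution): the first batch only ever contains values below buffer_size + batch_size = 30, no matter how often the pipeline is rebuilt.
import tensorflow as tf

for _ in range(10):
    dataset = tf.data.Dataset.range(100).shuffle(20).batch(10)
    first_batch = next(iter(dataset)).numpy()
    assert first_batch.max() < 30   # buffer_size + batch_size
    print(first_batch)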
According to the TensorFlow documentation, repeat() is used to iterate over a dataset for multiple epochs (an epoch is one complete pass over the dataset). In other words, it simply replicates the input data.
dataset = tf.data.Dataset.range(5) # [0 1 2 3 4]
for i in dataset.repeat(2).batch(3):
    print(i.numpy())
Output:
[0 1 2]
[3 4 0]
[1 2 3]
[4]
If we don't want to mix data from different epochs in one batch, we need to put repeat after batch.
dataset = tf.data.Dataset.range(5)
for i in dataset.batch(3).repeat(2):
    print(i.numpy())
Output:
[0 1 2]
[3 4]
[0 1 2]
[3 4]
Upvotes: 2
Reputation: 4757
Update: Here is a small Colab notebook demonstrating this answer.
Imagine you have a dataset: [1, 2, 3, 4, 5, 6], then:
How ds.shuffle() works
dataset.shuffle(buffer_size=3)
will allocate a buffer of size 3 for picking random entries. This buffer is connected to the source dataset. We could picture it like this:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
Let's assume that entry 2 was taken from the random buffer. The free slot is filled by the next element from the source dataset, that is 4:
2 <= [1,3,4] <= [5,6]
We continue reading till nothing is left:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
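Here is a rough sketch of the same idea with tf.data (assuming TF 2.x eager execution); the concrete order will differ between runs, but each element appears exactly once. Note that 6 can never be one of the first three outputs, because it only enters the buffer after three elements have already been drawn.
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 7)   # the elements 1..6 from the example above
for element in dataset.shuffle(buffer_size=3):
    print(element.numpy())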
How ds.repeat() works
As soon as all the entries are read from the dataset and you try to read the next element, the dataset will throw an error.
That's where ds.repeat() comes into play. It will re-initialize the dataset, making it again like this:
[1,2,3] <= [4,5,6]
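A minimal sketch of this behaviour (assuming TF 2.x): without repeat() the iteration simply ends after one pass, with repeat(2) the same elements are served twice.
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 7)                 # [1, 2, 3, 4, 5, 6]
print(list(dataset.as_numpy_iterator()))              # [1, 2, 3, 4, 5, 6]
print(list(dataset.repeat(2).as_numpy_iterator()))    # [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]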
What will ds.batch() produce
ds.batch() will take the first batch_size entries and make a batch out of them. So a batch size of 3 for our example dataset will produce two batch records:
[2,1,5]
[3,6,4]
As we have a ds.repeat() before the batch, the generation of the data will continue. But the order of the elements will be different, due to the ds.shuffle(). What should be taken into account is that 6 will never be present in the first batch, due to the size of the random buffer.
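A quick sketch of that last point (assuming TF 2.x): with a shuffle buffer of 3, element 6 only enters the buffer after three elements have already been drawn, so it can never land in the first batch.
import tensorflow as tf

for _ in range(10):
    dataset = tf.data.Dataset.range(1, 7).shuffle(3).repeat(2).batch(3)
    first_batch = next(iter(dataset)).numpy()
    assert 6 not in first_batch   # 6 is never in the first batch
    print(first_batch)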
Upvotes: 166
Reputation: 116
An example that shows looping over epochs. Upon running this script, notice the difference in:
dataset_gen1 - the shuffle operation produces more random outputs (this may be more useful while running machine learning experiments)
dataset_gen2 - the lack of a shuffle operation produces elements in sequence
Other additions in this script:
tf.data.experimental.sample_from_datasets - used to combine two datasets. Note that the shuffle operation in this case creates a buffer that samples equally from both datasets.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
class Augmentations:
def __init__(self):
pass
@tf.function
def filter_even(self, x):
if x % 2 == 0:
return False
else:
return True
class Dataset:
def __init__(self, aug, range_min=0, range_max=100):
self.range_min = range_min
self.range_max = range_max
self.aug = aug
def generator(self):
dataset = tf.data.Dataset.from_generator(self._generator
, output_types=(tf.float32), args=())
dataset = dataset.filter(self.aug.filter_even)
return dataset
def _generator(self):
for item in range(self.range_min, self.range_max):
yield(item)
# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:
def __init__(self, datasets):
self.datasets = datasets
self.datasets_generators = []
def generator(self):
for dataset in self.datasets:
self.datasets_generators.append(dataset.generator())
return tf.data.experimental.sample_from_datasets(self.datasets_generators)
if __name__ == "__main__":
aug = Augmentations()
dataset1 = Dataset(aug, 0, 100)
dataset2 = Dataset(aug, 100, 200)
dataset = ZipDataset([dataset1, dataset2])
epochs = 2
shuffle_buffer = 10
batch_size = 4
prefetch_buffer = 5
dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
# dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence
for epoch in range(epochs):
print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
print (X)
# Do some stuff at end of loop
Upvotes: 0
Reputation:
The following methods in tf.data.Dataset:
repeat(count=None)
Repeats the dataset count times. With the default count=None, it repeats indefinitely.
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)
Shuffles the samples of the dataset. buffer_size is the number of samples held in the shuffle buffer from which randomized elements are returned as part of the tf.data.Dataset.
batch(batch_size, drop_remainder=False)
Creates batches of the dataset with the batch size given as batch_size, which is also the number of elements in each batch.
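A small combined sketch (assuming TF 2.x eager execution) showing the three methods together:
import tensorflow as tf

dataset = tf.data.Dataset.range(8)
dataset = dataset.shuffle(buffer_size=8)   # randomize with a buffer of 8 samples
dataset = dataset.repeat(2)                # two passes over the data
dataset = dataset.batch(3)                 # group elements into batches of 3
for batch in dataset:
    print(batch.numpy())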
Upvotes: 12