Reputation: 6756
I came across this notebook, which covers forecasting; I got to it through this article. I am confused about the 2nd and 4th lines below:
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.cache().shuffle(buffer_size).batch(batch_size).repeat()
val_data = tf.data.Dataset.from_tensor_slices((x_vali, y_vali))
val_data = val_data.batch(batch_size).repeat()
I understand that we shuffle the data because we don't want to feed it to the model in its original serial order, and on further reading I learned that it is best to set buffer_size equal to the size of the dataset. But I am not sure what repeat is doing in this case. Could someone explain what is happening here and what the function of repeat is?
I also looked at this page and saw the text below, but it is still not clear to me.

The following methods in tf.data.Dataset:

repeat(count=0) — repeats the dataset count number of times.

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) — shuffles the samples in the dataset. buffer_size is the number of samples which are randomized and returned as tf.Dataset.

batch(batch_size, drop_remainder=False) — creates batches of the dataset with batch size given as batch_size, which is also the length of the batches.
Upvotes: 1
Views: 217
Reputation: 1941
Calling repeat with nothing passed to the count param makes the dataset repeat indefinitely.

In Python terms, a Dataset is an iterable. If you have an object ds of type tf.data.Dataset, you can call iter(ds) on it. If the dataset was produced by repeat(), the resulting iterator will never run out of items, i.e. it will never raise a StopIteration exception.
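You can see the difference with a toy dataset (the values here are just placeholders, not your actual training data):

```python
import tensorflow as tf

# A toy dataset of 3 samples.
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# Without repeat(): iteration stops after one pass over the data.
print([int(x) for x in ds])                  # [1, 2, 3]

# With repeat() and no count, the dataset cycles forever;
# take(7) is only here to cap the infinite stream for printing.
repeated = ds.repeat()
print([int(x) for x in repeated.take(7)])    # [1, 2, 3, 1, 2, 3, 1]
```

Without take(), iterating over repeated in a plain for loop would never terminate.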
In the notebook you referenced, tf.keras.Model.fit() is called with steps_per_epoch=100. That only works if the dataset repeats indefinitely: Keras pulls exactly 100 batches per epoch and pauses training to run validation after each 100 steps, regardless of where the dataset's natural end would fall.
tldr: leave it in.
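To make the steps_per_epoch interaction concrete, here is a sketch with made-up small numbers (10 samples, batch size 4, 3 steps per "epoch") that mimics how Keras consumes the repeating dataset:

```python
import tensorflow as tf

# 10 samples, batched by 4, then repeated indefinitely.
# Note: because batch() comes before repeat(), the short
# final batch [8, 9] is kept in every cycle.
ds = tf.data.Dataset.range(10).batch(4).repeat()

it = iter(ds)
steps_per_epoch = 3

# One "epoch" as Keras sees it: pull exactly steps_per_epoch
# batches from the infinite stream, no StopIteration possible.
epoch_batches = [next(it).numpy().tolist() for _ in range(steps_per_epoch)]
print(epoch_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The next epoch simply continues pulling from the same iterator, which is why fit() needs the dataset to never be exhausted.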
https://docs.python.org/3/library/exceptions.html
Upvotes: 1