Reputation: 31
I am using TensorFlow 2.
When using the Model.fit() method with a tf.data.Dataset, the argument 'batch_size' is ignored. Thus, to train my model on batches, I first have to convert my dataset of samples into a dataset of batches of samples by calling tf.data.Dataset.batch(batch_size).
Then, after reading the documentation, I don't clearly understand how the .fit() method will shuffle my dataset at each epoch.
Since my dataset is a dataset of batches, will it shuffle the batches among each other (the contents of each batch remain unchanged)? Or will it shuffle all the samples and then regroup them into new batches (which is the desired behaviour)?
Thanks a lot for your help.
Upvotes: 3
Views: 1821
Reputation: 11651
The shuffle parameter has no effect on the fit function when using the tf.data.Dataset API.
If we read the documentation (emphasis is mine):
shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch'). This argument is ignored when x is a generator. 'batch' is a special option for dealing with the limitations of HDF5 data; it shuffles in batch-sized chunks. Has no effect when steps_per_epoch is not None.
It's not entirely clear, but this hints that the shuffle argument is ignored when using a tf.data.Dataset, since it behaves like a generator.
To be certain, let's dive into the code. If you look at the code of the fit method, you will see that the data is handled by a special class, DataHandler. Looking at the code of this class, we see that it relies on adapter classes to handle the different kinds of data. We are interested in the class that handles tf.data.Dataset, DatasetAdapter, and we can see that this class does not take the shuffle parameter into account:
    def __init__(self,
                 x,
                 y=None,
                 sample_weights=None,
                 steps=None,
                 **kwargs):
        super(DatasetAdapter, self).__init__(x, y, **kwargs)
        # Note that the dataset instance is immutable, its fine to reuse the user
        # provided dataset.
        self._dataset = x

        # The user-provided steps.
        self._user_steps = steps

        self._validate_args(y, sample_weights, steps)
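If you want to shuffle your dataset, use the shuffle function from the tf.data.Dataset API. Call it before batch so that individual samples, not whole batches, are shuffled. With the default reshuffle_each_iteration=True, a new order is drawn at every epoch.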
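To make the difference between the two orderings concrete, here is a minimal sketch on a toy dataset of the integers 0 to 9. The helper names make_datasets and one_epoch, the sample count and the batch size are illustrative assumptions, not part of TensorFlow. Only Dataset.range, shuffle, batch and eager iteration are real API calls.

```python
import tensorflow as tf


def make_datasets(num_samples=10, batch_size=4):
    """Build the two pipelines discussed in the question (hypothetical helper)."""
    samples = tf.data.Dataset.range(num_samples)

    # Shuffle the individual samples first, then group them into batches.
    # A buffer as large as the dataset gives a full shuffle; the order is
    # redrawn on every pass because reshuffle_each_iteration=True.
    shuffle_then_batch = samples.shuffle(
        buffer_size=num_samples, reshuffle_each_iteration=True
    ).batch(batch_size)

    # Group first, then shuffle: only the order of the batches changes,
    # the content of each batch stays fixed.
    batch_then_shuffle = samples.batch(batch_size).shuffle(buffer_size=num_samples)

    return shuffle_then_batch, batch_then_shuffle


def one_epoch(dataset):
    """Iterate once over the dataset and return its batches as Python lists."""
    return [batch.numpy().tolist() for batch in dataset]


if __name__ == "__main__":
    shuffle_then_batch, batch_then_shuffle = make_datasets()
    for epoch in range(2):
        print("epoch", epoch)
        print("  shuffle -> batch:", one_epoch(shuffle_then_batch))
        print("  batch -> shuffle:", one_epoch(batch_then_shuffle))
```

With the second pipeline, every epoch contains the same three batches ([0, 1, 2, 3], [4, 5, 6, 7] and [8, 9]) in a varying order, whereas the first pipeline regroups the samples into new batches each time. Assuming a compiled Keras model named model (hypothetical here), the first pipeline would be passed directly as model.fit(shuffle_then_batch, epochs=...), without batch_size or shuffle arguments.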
Upvotes: 1