Nick Skywalker

Reputation: 1093

Should we apply repeat, batch, and shuffle to a tf.data.Dataset when passing it to the fit function?

I still don't understand this after reading the documentation for tf.keras.Model.fit and tf.data.Dataset. When passing a tf.data.Dataset to fit, should I call repeat and batch on the dataset object, or should I provide the batch_size and epochs arguments to fit instead? Or both? Should I apply the same treatment to the validation set?

And while I'm here, can I shuffle the dataset before calling fit? (It seems like an obvious yes.) If so, should I shuffle before or after calling Dataset.batch and Dataset.repeat (if I call them)?

Edit: When using the batch_size argument without having called Dataset.batch(batch_size) first, I get the following error:

ValueError: The `batch_size` argument must not be specified for the given input type.
Received input: <MapDataset shapes: ((<unknown>, <unknown>, <unknown>, <unknown>), (<unknown>, <unknown>, <unknown>)), 
types: ((tf.float32, tf.float32, tf.float32, tf.float32), (tf.float32, tf.float32, tf.float32))>, 
batch_size: 1

Thanks

Upvotes: 0

Views: 1679

Answers (1)

Frederik Bode

Reputation: 2744

There are different ways to do what you want here, but the one I always use is:

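import tensorflow as tf

batch_size = 32
# ds: your unbatched tf.data.Dataset; len_ds: its number of elements
# reshuffle_each_iteration=False keeps one fixed order, so take() and skip()
# yield the same disjoint splits every time they are iterated
ds = ds.shuffle(len_ds, reshuffle_each_iteration=False)
len_train_ds = int(0.8 * len_ds)  # take()/skip() need an integer count
len_validation_ds = len_ds - len_train_ds
train_ds = ds.take(len_train_ds)
# reshuffle the training examples every epoch, then repeat forever and batch
train_ds = train_ds.shuffle(len_train_ds).repeat().batch(batch_size)
validation_ds = ds.skip(len_train_ds)
validation_ds = validation_ds.repeat().batch(batch_size)
model.fit(train_ds,
          steps_per_epoch=len_train_ds // batch_size,
          validation_data=validation_ds,
          validation_steps=len_validation_ds // batch_size,
          epochs=5)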
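Why reshuffle_each_iteration=False matters for the split: by default Dataset.shuffle reshuffles every time the dataset is iterated, so take() and skip() applied after it can see different permutations and leak validation examples into training. A minimal check, using a toy Dataset.range(100) in place of real data and a hypothetical helper split_dataset (the seed is only there to make the sketch deterministic):

```python
import tensorflow as tf


def split_dataset(ds, len_ds, train_fraction=0.8, seed=42):
    """Hypothetical helper: shuffle once, then split into disjoint train/validation sets."""
    # With reshuffle_each_iteration=False the permutation is fixed, so take() and
    # skip() always see the same order and never share elements.
    ds = ds.shuffle(len_ds, seed=seed, reshuffle_each_iteration=False)
    len_train = int(train_fraction * len_ds)
    return ds.take(len_train), ds.skip(len_train)


ds = tf.data.Dataset.range(100)
train_ds, validation_ds = split_dataset(ds, 100)

train_ids = {int(x) for x in train_ds}
val_ids = {int(x) for x in validation_ds}
print(len(train_ids), len(val_ids), train_ids.isdisjoint(val_ids))  # prints: 80 20 True
```

Iterating train_ds a second time yields the same 80 elements, which is exactly what keeps the two splits separate across epochs.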

This way you still have access to all the datasets after fitting; for example, you can visualize the validation set. That is not really possible with validation_split, which fit does not support for tf.data.Dataset inputs anyway. Keep the .batch(batch_size) calls: fit does not batch a tf.data.Dataset for you (that is why passing batch_size raises the ValueError in the question), so steps_per_epoch and validation_steps are counted in batches, hence the // batch_size.
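A self-contained sketch of the same pipeline on synthetic data (the arrays, sizes, and one-layer model are made up for illustration). It shuffles individual examples before repeat() and batch(), so each epoch is reshuffled and batches mix examples; it does not pass batch_size to fit; and it relies on steps_per_epoch and validation_steps because both datasets are infinite:

```python
import numpy as np
import tensorflow as tf

batch_size = 32
len_train_ds, len_validation_ds = 256, 64

# Made-up regression data standing in for a real dataset.
rng = np.random.default_rng(0)
x = rng.normal(size=(len_train_ds + len_validation_ds, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

train_ds = tf.data.Dataset.from_tensor_slices((x[:len_train_ds], y[:len_train_ds]))
validation_ds = tf.data.Dataset.from_tensor_slices((x[len_train_ds:], y[len_train_ds:]))

# Shuffle examples first, then repeat forever, then batch.
train_ds = train_ds.shuffle(len_train_ds).repeat().batch(batch_size)
# Validation data does not need shuffling; repeat and batch it the same way.
validation_ds = validation_ds.repeat().batch(batch_size)

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# The datasets are already batched, so no batch_size argument here.
history = model.fit(train_ds,
                    steps_per_epoch=len_train_ds // batch_size,
                    validation_data=validation_ds,
                    validation_steps=len_validation_ds // batch_size,
                    epochs=3,
                    verbose=0)
print(sorted(history.history))
```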

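You still pass epochs to fit (it defaults to 1). Because .repeat() with no argument makes the datasets infinite, steps_per_epoch and validation_steps are what tell Keras where each epoch ends.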
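repeat() is optional, though. If the dataset stays finite, fit makes one full pass over it per epoch, so epochs alone controls training and steps_per_epoch can be omitted. A minimal sketch, again with made-up data and a made-up model:

```python
import numpy as np
import tensorflow as tf

# Made-up data: 100 examples, so batch(32) yields 4 batches (the last one has 4 examples).
x = np.linspace(-1.0, 1.0, 100, dtype="float32").reshape(-1, 1)
y = 2.0 * x + 1.0

# Finite dataset: shuffle examples, then batch; no repeat().
ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(100).batch(32)

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Each epoch iterates ds until it is exhausted; shuffle() reshuffles every epoch by default.
history = model.fit(ds, epochs=2, verbose=0)
print(len(history.history["loss"]))
```

The trade-off: with a finite dataset you do not need to know its length, whereas the repeat() approach needs len_train_ds to compute steps_per_epoch.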

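If you don't know len_ds, counting the elements of a tf.data.Dataset requires looping over it once. Do this on the original, finite dataset, before calling .repeat() (which makes the dataset infinite, so the loop would never end) and before .batch() (after which the loop would count batches rather than examples):

len_ds = 0
for _ in ds:
  len_ds += 1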
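Recent TensorFlow 2.x releases also provide Dataset.cardinality(), which often returns the length without iterating. It returns a negative value when the length is unknown (for example after filter()) or infinite (after repeat()), so a fallback to counting is still needed. A sketch with a hypothetical helper dataset_length:

```python
import tensorflow as tf


def dataset_length(ds):
    """Hypothetical helper: number of elements in a finite tf.data.Dataset."""
    n = int(ds.cardinality())
    if n >= 0:
        return n  # length known statically, no iteration needed
    if n == -1:  # -1 means infinite, e.g. after repeat(); counting would never finish
        raise ValueError("dataset is infinite; count it before calling repeat()")
    # -2 means unknown: iterate once and count, like the for loop above.
    return int(ds.reduce(0, lambda count, _: count + 1))


ds = tf.data.Dataset.range(10)
print(dataset_length(ds))
print(dataset_length(ds.filter(lambda x: x % 2 == 0)))
```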

Upvotes: 3
