Reputation: 1093
Even after reading the documentation for tf.keras.Model.fit and tf.data.Dataset, I still don't understand: when passing a tf.data.Dataset to fit, should I call repeat and batch on the dataset object, should I provide the batch_size and epochs arguments to fit instead, or both? Should I apply the same treatment to the validation set?
And while I'm here, can I shuffle the dataset before the fit? (Seems like an obvious yes.) If so, should shuffle come before or after calling Dataset.batch and Dataset.repeat (if I call them at all)?
Edit: when using the batch_size argument without having called Dataset.batch(batch_size) beforehand, I get the following error:
ValueError: The `batch_size` argument must not be specified for the given input type.
Received input: <MapDataset shapes: ((<unknown>, <unknown>, <unknown>, <unknown>), (<unknown>, <unknown>, <unknown>)),
types: ((tf.float32, tf.float32, tf.float32, tf.float32), (tf.float32, tf.float32, tf.float32))>,
batch_size: 1
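For reference, a minimal sketch of the kind of call that produces this (toy data and model for illustration only, not my actual pipeline):

import tensorflow as tf

# Toy dataset and model, just to reproduce the failing combination.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([100, 4]), tf.random.normal([100, 1])))
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Raises the ValueError above: batch_size cannot be combined with a tf.data.Dataset input.
model.fit(ds, batch_size=1, epochs=1)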
Thanks
Upvotes: 0
Views: 1679
Reputation: 2744
There are different ways to do what you want here, but the one I always use is:
batch_size = 32

ds = ...  # your (unbatched) tf.data.Dataset; len_ds is its total number of examples
# Shuffle once; reshuffle_each_iteration=False keeps the train/validation split stable.
ds = ds.shuffle(len_ds, reshuffle_each_iteration=False)

# 80/20 split; take/skip expect integer counts.
train_ds = ds.take(int(0.8 * len_ds))
train_ds = train_ds.repeat().batch(batch_size)

validation_ds = ds.skip(int(0.8 * len_ds))
validation_ds = validation_ds.repeat().batch(batch_size)

model.fit(train_ds,
          steps_per_epoch=len_train_ds // batch_size,
          validation_data=validation_ds,
          validation_steps=len_validation_ds // batch_size,
          epochs=5)
This way you still have access to all the variables after fitting; for example, if you want to visualize the validation set, you can. That is not really possible with validation_split. If you remove .batch(batch_size), you should also remove the // batch_size divisions, but I would leave them in, as it makes it clearer what is happening.
You always have to provide epochs.
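For example, a rough sketch of what I mean by still having access to the validation set afterwards (assuming eager execution and that the dataset yields simple (features, labels) tensor pairs):

# Pull one batch out of the validation pipeline to inspect or plot it.
for features, labels in validation_ds.take(1):
    print(features.shape, labels.shape)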
If your train/validation sets are in tf.data.Dataset form, calculating their length requires you to loop over them (do this before calling repeat, otherwise the loop never terminates):
len_train_ds = 0
for _ in train_ds:
    len_train_ds += 1
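As a side note (assuming you are on a reasonably recent TensorFlow), tf.data.experimental.cardinality can sometimes give you the length without looping; it returns UNKNOWN_CARDINALITY when the length is not statically known:

# Ask TF for the dataset length; fall back to the counting loop if it is unknown.
# Do this before repeat(), which makes the cardinality infinite.
card = tf.data.experimental.cardinality(train_ds)
if card != tf.data.experimental.UNKNOWN_CARDINALITY:
    len_train_ds = int(card)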
Upvotes: 3