Reputation: 91
I'm following this Keras tutorial that explains how to write a custom train_step() function while still being able to call model.fit() to train the model:
https://keras.io/guides/customizing_what_happens_in_fit/
model.fit() should support validation_data, but I can't figure out where to put the code that computes custom metrics and custom losses for validation_data. For now I've fallen back to writing a custom training loop, but I would rather use fit().
Any ideas?
Upvotes: 3
Views: 1237
Reputation: 91
I completely missed the paragraph of the guide that mentions the test_step() function, which is what fit() calls on each batch of validation_data:
def test_step(self, data):
    # Unpack the data
    x, y = data
    # Compute predictions
    y_pred = self(x, training=False)
    # Updates the metrics tracking the loss
    self.compiled_loss(y, y_pred, regularization_losses=self.losses)
    # Update the metrics.
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value.
    # Note that it will include the loss (tracked in self.metrics).
    return {m.name: m.result() for m in self.metrics}
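For anyone who wants to see how this fits together with fit(), here is a minimal sketch rather than the guide's exact code. It assumes the TensorFlow backend, and the class name CustomModel, the toy data, the layer sizes, and the choice of MSE loss with an MAE metric are all my own placeholders. Unlike the snippet above, it tracks the loss and metric by hand instead of going through compiled_loss and compiled_metrics, so the custom validation logic is fully visible in test_step().

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    """Hypothetical model with hand-written training and validation steps."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Trackers we update ourselves in train_step() and test_step().
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    @property
    def metrics(self):
        # Listing the trackers here lets fit() reset them at each epoch
        # and at the start of each validation pass.
        return [self.loss_tracker, self.mae_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Custom training loss (plain MSE, chosen for illustration).
            loss = tf.reduce_mean(tf.square(y - y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # Called by fit() for every batch of validation_data.
        x, y = data
        y_pred = self(x, training=False)
        # Custom validation loss and metric go here.
        loss = tf.reduce_mean(tf.square(y - y_pred))
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


# Toy regression data, made up for the example.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = rng.normal(size=(64, 1)).astype("float32")
x_val = rng.normal(size=(16, 4)).astype("float32")
y_val = rng.normal(size=(16, 1)).astype("float32")

inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# No loss is passed to compile(): it is computed inside the two step methods.
model.compile(optimizer="adam")
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=1,
    batch_size=16,
    verbose=0,
)
print(sorted(history.history.keys()))
```

Whatever keys test_step() returns show up in the fit() history with a val_ prefix, so this is the place to add or swap validation-only metrics.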
Upvotes: 3