Francesco

Reputation: 91

How do I modify train_step in order to support a validation set when called in model.fit()?

I'm following this Keras tutorial, which explains how to write your own custom train_step() function while still being able to call model.fit() to train the model:

https://keras.io/guides/customizing_what_happens_in_fit/
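For context, here is a minimal sketch of the pattern the guide describes, in its lower-level style: the loss and metrics are computed by hand inside train_step(), and fit() still drives the loop. The one-layer Dense model, the MeanSquaredError loss, the MAE metric and the random data are illustrative assumptions, not taken from the question.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    """Model with a hand-written training step that is still driven by fit()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Illustrative stand-ins for a custom loss and a custom metric.
        self.loss_fn = keras.losses.MeanSquaredError()
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    @property
    def metrics(self):
        # Listing the trackers here lets fit() reset them at each epoch.
        return [self.loss_tracker, self.mae_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam")  # no loss here: train_step computes it

rng = np.random.default_rng(0)
x = rng.random((64, 4), dtype=np.float32)
y = rng.random((64, 1), dtype=np.float32)
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
```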

model.fit() is supposed to support validation_data, but I can't work out where to put the code that computes custom metrics and custom losses for validation_data. I've resorted to writing a custom training loop, but I would rather use fit().

Any ideas?

Upvotes: 3

Views: 1237

Answers (1)

Francesco

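I had completely missed the section of the guide that covers the test_step() function. When validation_data is passed, model.fit() runs test_step() on it at the end of each epoch (by default), and model.evaluate() uses it as well, so that's where the validation losses and metrics are computed: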

from tensorflow import keras


class CustomModel(keras.Model):
    # Written against tf.keras in TensorFlow 2.x, as in the guide.
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions in inference mode
        y_pred = self(x, training=False)
        # Update the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Update the metrics passed to compile()
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
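The snippet above relies on the loss and metrics passed to compile(). If you compute a custom loss and custom metrics by hand, as in the guide's lower-level examples, a sketch that pairs train_step() with a matching test_step() might look like this. The MeanSquaredError loss, MAE metric, layer sizes and random data are placeholder assumptions; swap in your own.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholder "custom" loss and metric (assumptions, not from the question).
        self.loss_fn = keras.losses.MeanSquaredError()
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

    @property
    def metrics(self):
        # fit() and evaluate() reset these at the start of each epoch
        # and before each validation pass.
        return [self.loss_tracker, self.mae_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # Called by fit() on validation_data, and by evaluate().
        x, y = data
        y_pred = self(x, training=False)  # inference mode, no gradient update
        loss = self.loss_fn(y, y_pred)
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam")

rng = np.random.default_rng(0)
x_train = rng.random((64, 4), dtype=np.float32)
y_train = rng.random((64, 1), dtype=np.float32)
x_val = rng.random((32, 4), dtype=np.float32)
y_val = rng.random((32, 1), dtype=np.float32)

history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=2, batch_size=16, verbose=0,
)
print(sorted(history.history))
```

Because fit() prefixes the results of the validation pass with val_, history.history should contain loss and mae from train_step() plus val_loss and val_mae from test_step().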

Upvotes: 3
