ztsv-av

Reputation: 109

Does model.fit() reset metrics after each epoch? How to reset metrics manually?

As far as I understand, model.fit(epochs=NUM_EPOCHS) does not reset metrics for each epoch. My code for metrics and model.fit() looks like this (simplified):

import tensorflow as tf
import tensorflow_addons as tfa  # required for tfa.metrics.F1Score used below
from tensorflow.keras import applications

NUM_CLASSES = 4
NUM_EPOCHS = 10  # placeholder value; used by model.fit() below
INPUT_SHAPE = (256, 256, 3)
MODELS = {
    'DenseNet121': applications.DenseNet121,
    'DenseNet169': applications.DenseNet169
}
REDUCE_LR_PATIENCE = 2
REDUCE_LR_FACTOR = 0.7
EARLY_STOPPING_PATIENCE = 4


for modelName, model in MODELS.items():

    loadedModel = model(include_top=False, weights='imagenet',
                        pooling='avg', input_shape=INPUT_SHAPE)

    sequentialModel = tf.keras.models.Sequential()
    sequentialModel.add(loadedModel)
    sequentialModel.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

    aucCurve = tf.keras.metrics.AUC(curve='ROC', multi_label=True)
    categoricalAccuracy = tf.keras.metrics.CategoricalAccuracy()
    F1Score = tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='macro', threshold=None)
    metrics = [aucCurve, categoricalAccuracy, F1Score]

    # loss/optimizer are assumed here (the snippet is simplified); categorical_crossentropy matches the softmax output
    sequentialModel.compile(optimizer='adam', loss='categorical_crossentropy', metrics=metrics)

    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=REDUCE_LR_PATIENCE, verbose=1, factor=REDUCE_LR_FACTOR),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', verbose=1, patience=EARLY_STOPPING_PATIENCE),
        tf.keras.callbacks.ModelCheckpoint(filepath=modelName + '_epoch-{epoch:02d}.h5', monitor='val_loss', save_best_only=False, verbose=1),
        tf.keras.callbacks.CSVLogger(modelName + '_training.csv')
    ]

    # training and validation data arguments are omitted in this simplified snippet
    sequentialModel.fit(epochs=NUM_EPOCHS, callbacks=callbacks)

Perhaps I could reset the metrics by wrapping model.fit() in a for loop over range(NUM_EPOCHS) and re-initializing the metrics on every iteration, but I am not sure that is a good solution. Also, I have the ModelCheckpoint and CSVLogger callbacks, which rely on the epoch number supplied by model.fit(), so a manual loop would not really work with them.

Do you have any suggestions on how to reset the metrics for each epoch? Is a for loop over range(NUM_EPOCHS) the only solution here? Thank you.
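
For reference, this is the kind of manual reset I have in mind, sketched as a custom callback rather than an outer for loop (ResetMetricsCallback is just an illustrative name, and reset_state() assumes TF 2.5+; older versions use reset_states()):

import tensorflow as tf

class ResetMetricsCallback(tf.keras.callbacks.Callback):
    # Illustrative sketch: explicitly clear every compiled metric at the start of each epoch.
    def on_epoch_begin(self, epoch, logs=None):
        for metric in self.model.metrics:
            metric.reset_state()

Something like this could simply be appended to the callbacks list passed to model.fit(), so ModelCheckpoint and CSVLogger would still see the usual epoch numbers.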

Upvotes: 3

Views: 1117

Answers (2)

Maciej Skorski

Reputation: 3354

The behaviour is controlled by the metric's reset_state method.

Usually, it looks like this:

def reset_state(self):
    # Reset the metric state at the start of each epoch.
    self.my_state_variable.assign(0.0)

But it can be defined differently if needed (e.g. when subclassing tf.keras.metrics.Metric to write your own metrics).
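
For instance, a minimal sketch of a custom metric (the name MaxAbsError and its state variable are purely illustrative) where reset_state is defined explicitly:

import tensorflow as tf

class MaxAbsError(tf.keras.metrics.Metric):
    # Tracks the largest absolute error seen so far in the current epoch.
    def __init__(self, name='max_abs_error', **kwargs):
        super().__init__(name=name, **kwargs)
        self.max_error = self.add_weight(name='max_error', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        batch_max = tf.reduce_max(tf.abs(tf.cast(y_true, y_pred.dtype) - y_pred))
        self.max_error.assign(tf.maximum(self.max_error, batch_max))

    def result(self):
        return self.max_error

    def reset_state(self):
        # Called whenever Keras resets metrics, e.g. between epochs.
        self.max_error.assign(0.0)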

Upvotes: 0

mujjiga

Reputation: 16896

No, metrics are calculated per epoch. They are not averaged across epochs; rather, they are averaged over the batches within each epoch and reset when a new epoch starts. You see the metrics keep improving epoch after epoch because your model is getting trained.
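
A minimal sketch of that behaviour with a standalone metric object (the values are just illustrative):

import tensorflow as tf

m = tf.keras.metrics.CategoricalAccuracy()

# Within one "epoch" the metric accumulates over batches and reports a running average.
m.update_state([[0, 1]], [[0.3, 0.7]])   # batch 1: prediction correct
m.update_state([[1, 0]], [[0.2, 0.8]])   # batch 2: prediction wrong
print(m.result().numpy())                # 0.5, averaged over the two batches so far

# Between epochs the state is cleared, so every epoch starts from scratch.
m.reset_state()
print(m.result().numpy())                # 0.0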

Upvotes: 4
