As far as I understand, model.fit(epochs=NUM_EPOCHS) does not reset metrics for each epoch. My code for metrics and model.fit() looks like this (simplified):
import tensorflow as tf
import tensorflow_addons as tfa  # provides tfa.metrics.F1Score
from tensorflow.keras import applications

NUM_CLASSES = 4
INPUT_SHAPE = (256, 256, 3)
MODELS = {
    'DenseNet121': applications.DenseNet121,
    'DenseNet169': applications.DenseNet169
}
REDUCE_LR_PATIENCE = 2
REDUCE_LR_FACTOR = 0.7
EARLY_STOPPING_PATIENCE = 4

for modelName, model in MODELS.items():
    # Pretrained backbone (global average pooling) followed by a softmax classifier.
    loadedModel = model(include_top=False, weights='imagenet',
                        pooling='avg', input_shape=INPUT_SHAPE)
    sequentialModel = tf.keras.models.Sequential()
    sequentialModel.add(loadedModel)
    sequentialModel.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))

    # Fresh metric objects for every model.
    aucCurve = tf.keras.metrics.AUC(curve='ROC', multi_label=True)
    categoricalAccuracy = tf.keras.metrics.CategoricalAccuracy()
    F1Score = tfa.metrics.F1Score(num_classes=NUM_CLASSES, average='macro', threshold=None)
    metrics = [aucCurve, categoricalAccuracy, F1Score]
    # fit() needs a loss; categorical cross-entropy is assumed to match the softmax output.
    sequentialModel.compile(loss='categorical_crossentropy', metrics=metrics)

    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=REDUCE_LR_PATIENCE, verbose=1, factor=REDUCE_LR_FACTOR),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', verbose=1, patience=EARLY_STOPPING_PATIENCE),
        tf.keras.callbacks.ModelCheckpoint(filepath=modelName + '_epoch-{epoch:02d}.h5', monitor='val_loss', save_best_only=False, verbose=1),
        tf.keras.callbacks.CSVLogger(modelName + '_training.csv')]

    # NUM_EPOCHS and the datasets are defined elsewhere (omitted in this simplified
    # snippet); train_ds and val_ds are placeholder names.
    sequentialModel.fit(train_ds, validation_data=val_ds,
                        epochs=NUM_EPOCHS, callbacks=callbacks)
Perhaps I could reset the metrics by looping over range(NUM_EPOCHS) myself and re-initializing them on every iteration, but I am not sure that is a good solution. Also, I use the ModelCheckpoint and CSVLogger callbacks, which rely on the epoch number that model.fit() provides, so a manual loop would not really work with them.
Do you have any suggestions on how to reset the metrics at each epoch? Is a loop over range(NUM_EPOCHS) the only solution here? Thank you.
Upvotes: 3
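The behaviour is controlled by each metric's reset_state() method, which model.fit() calls on every compiled metric at the start of each epoch.
A typical implementation looks like this: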
def reset_state(self):
    # Reset the metric state at the start of each epoch.
    self.my_state_variable.assign(0.0)
It can be defined differently if needed, e.g. in your own metrics when you subclass tf.keras.metrics.Metric.
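To make the reset cycle concrete, here is a minimal framework-free sketch in plain Python. StreamingAccuracy and run_epochs are hypothetical names, not Keras APIs: the class only mirrors the update_state / result / reset_state protocol that Keras metrics follow, and run_epochs imitates the way fit() resets every metric before each epoch while accumulating across that epoch's batches.

```python
class StreamingAccuracy:
    """Hypothetical stand-in for a stateful Keras-style metric (not a TF class)."""

    def __init__(self):
        self.reset_state()

    def update_state(self, y_true, y_pred):
        # Accumulate running totals; nothing is cleared between batches.
        for t, p in zip(y_true, y_pred):
            self.correct += int(t == p)
            self.count += 1

    def result(self):
        # Accuracy over every sample seen since the last reset.
        return self.correct / self.count if self.count else 0.0

    def reset_state(self):
        # Plays the role of `self.my_state_variable.assign(0.0)` for each state variable.
        self.correct = 0
        self.count = 0


def run_epochs(metric, epochs):
    """Hypothetical helper mimicking fit(): reset at the start of each epoch."""
    history = []
    for batches in epochs:
        metric.reset_state()
        for y_true, y_pred in batches:
            metric.update_state(y_true, y_pred)
        history.append(metric.result())
    return history


# Each epoch is a list of (labels, predicted labels) batches.
epoch_1 = [([0, 1, 2], [0, 2, 2]), ([1, 1], [0, 1])]  # 3 correct out of 5
epoch_2 = [([0, 1, 2], [0, 1, 2]), ([1, 1], [1, 0])]  # 4 correct out of 5
print(run_epochs(StreamingAccuracy(), [epoch_1, epoch_2]))  # [0.6, 0.8]
```

Without the reset_state() call inside run_epochs, the second epoch would report 0.7 (7 correct out of 10 samples pooled across both epochs) instead of 0.8.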
Upvotes: 0
No, metrics are computed per epoch. They are not accumulated across epochs; each metric aggregates its state over the batches of the current epoch and is reset when the next epoch begins. The metrics keep improving from epoch to epoch because your model is being trained.
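"Aggregated over the batches" is not the same as averaging the per-batch scores, which matters when batches differ in size (e.g. a smaller final batch). The plain-Python sketch below uses two hypothetical helpers to compute both for one epoch. Assuming unweighted samples, a mean-style metric such as CategoricalAccuracy reports the pooled value, and AUC likewise accumulates confusion-matrix counts over the epoch rather than averaging per-batch AUCs.

```python
def mean_of_batch_accuracies(batches):
    """Average of per-batch accuracies (shown for contrast; not what Keras reports)."""
    per_batch = [sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
                 for y_true, y_pred in batches]
    return sum(per_batch) / len(per_batch)


def accumulated_accuracy(batches):
    """Pool correct predictions and sample counts over all batches of the epoch."""
    correct = sum(t == p for y_true, y_pred in batches for t, p in zip(y_true, y_pred))
    count = sum(len(y_true) for y_true, _ in batches)
    return correct / count


# One epoch: a full batch of 4 samples and a smaller last batch of 2.
epoch = [([0, 1, 2, 3], [0, 1, 2, 0]),  # 3 of 4 correct
         ([1, 2], [0, 0])]              # 0 of 2 correct
print(mean_of_batch_accuracies(epoch))  # 0.375
print(accumulated_accuracy(epoch))      # 0.5
```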
Upvotes: 4