I recently switched from TensorFlow 2.2.0 to 2.4.1, and now I have a problem with the ModelCheckpoint callback path. This code works fine in an environment with TF 2.2, but I get an error with TF 2.4.1:
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')
history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint])
Error:
KeyError: 'Failed to format this callback filepath: "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}". Reason: 'lr''
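The message is a `KeyError` raised while the path template is filled in with `str.format`. The plain-Python sketch below reproduces it, assuming (as in the tf.keras 2.4 source) that ModelCheckpoint fills the template roughly as `filepath.format(epoch=epoch + 1, **logs)`. `format_checkpoint_path` is a hypothetical helper, not a Keras API, and the logs values are made up.

```python
def format_checkpoint_path(template, epoch, logs):
    # Hypothetical stand-in for the formatting step inside ModelCheckpoint:
    # `epoch` is passed explicitly (shown 1-based), every other placeholder comes from `logs`.
    return template.format(epoch=epoch + 1, **logs)


template = "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}"

# Illustrative values only; note there is no 'lr' entry.
logs = {"loss": 0.0123, "val_loss": 0.0456}

try:
    format_checkpoint_path(template, epoch=0, logs=logs)
except KeyError as err:
    # str.format names the missing placeholder, just like the callback's "Reason: 'lr'".
    print("missing key:", err)
```

Once an `lr` entry is present in the dict, the same call formats cleanly.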
In ModelCheckpoint, the placeholders in the formatted filepath argument can only be epoch plus the keys present in logs when the epoch ends.
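Given that rule, a template can be checked against a list of log keys before training. This is a minimal sketch using the standard library's `string.Formatter().parse`; `missing_checkpoint_keys` is a hypothetical helper (not part of Keras), and the log keys below are illustrative.

```python
import string


def missing_checkpoint_keys(template, log_keys):
    """Return placeholder names in `template` that neither `epoch` nor `log_keys` provide.

    Hypothetical pre-flight check, not part of Keras.
    """
    available = {"epoch", *log_keys}
    missing = []
    for _literal, field_name, _spec, _conversion in string.Formatter().parse(template):
        if field_name is None:
            continue  # literal text with no placeholder after it
        # Reduce attribute/index access such as {x.y} or {x[0]} to the base name.
        name = field_name.split(".")[0].split("[")[0]
        if name not in available and name not in missing:
            missing.append(name)
    return missing


template = "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}"
log_keys = ["loss", "mean_absolute_error", "val_loss", "val_mean_absolute_error"]
print(missing_checkpoint_keys(template, log_keys))
```

With these keys the function reports `['lr']`; it returns an empty list once `lr` is added to the keys.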
You can see available keys in logs like this:
import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Print the keys Keras puts into logs at the end of each epoch.
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))

model.fit(..., callbacks=[CustomCallback()])
If you run the code above, you will see something like this:
Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
This shows the keys you can use (plus epoch), and lr is not among them. Your filepath name uses three keys: epoch, lr and val_loss.
Solution:
You can add the learning rate to logs yourself:
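import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Add the optimizer's current learning rate to the shared logs dict.
        logs.update({'lr': K.eval(self.model.optimizer.lr)})
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))  # you will now see `lr` available

# The lr field is labelled "lr" here instead of a second "loss".
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_lr-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

# CustomCallback goes before checkpoint: callbacks run in list order, so 'lr'
# must already be in logs when ModelCheckpoint formats the path.
history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[CustomCallback(), checkpoint])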
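Callback order matters for this fix. Assuming, as in tf.keras 2.4, that every callback's `on_epoch_end` receives the same logs dict in list order, the callback that writes `lr` has to run before the one that formats the path. The plain-Python simulation below illustrates this; `AddLearningRate`, `FormatPath` and `run_epoch_end` are hypothetical stand-ins, not Keras classes, and the numbers are made up.

```python
class AddLearningRate:
    """Hypothetical stand-in for CustomCallback: writes 'lr' into the shared logs."""

    def __init__(self, lr):
        self.lr = lr

    def on_epoch_end(self, epoch, logs):
        logs["lr"] = self.lr


class FormatPath:
    """Hypothetical stand-in for ModelCheckpoint: formats the path from the logs."""

    def __init__(self, template):
        self.template = template
        self.paths = []

    def on_epoch_end(self, epoch, logs):
        self.paths.append(self.template.format(epoch=epoch + 1, **logs))


def run_epoch_end(callbacks, epoch, logs):
    # Assumed behaviour: every callback sees the same dict, in list order.
    for callback in callbacks:
        callback.on_epoch_end(epoch, logs)


template = "epoch-{epoch}_lr-{lr:.2e}_loss-{val_loss:.3e}"

# Adder first: the formatter sees 'lr' and succeeds.
saver = FormatPath(template)
run_epoch_end([AddLearningRate(0.001), saver], 0, {"val_loss": 0.0456})
print(saver.paths)

# Formatter first: 'lr' is not in the logs yet, so the first epoch fails.
try:
    run_epoch_end([FormatPath(template), AddLearningRate(0.001)], 0, {"val_loss": 0.0456})
except KeyError as err:
    print("missing key:", err)
```

The first call records `epoch-1_lr-1.00e-03_loss-4.560e-02`; the reversed order fails with the same missing `'lr'` key reported in the question.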