theastronomist

Reputation: 1056

KeyError: 'Failed to format this callback filepath: Reason: \'lr\''

I recently switched from TensorFlow 2.2.0 to 2.4.1 and now I have a problem with the ModelCheckpoint callback path. This code works fine in an environment with tf 2.2, but raises an error when I use tf 2.4.1.

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint])

Error:

KeyError: 'Failed to format this callback filepath: "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}". Reason: 'lr''

Upvotes: 0

Views: 1328

Answers (1)

Kaveh

Reputation: 4960

In ModelCheckpoint, the formatted name of the filepath argument can only contain epoch plus the keys that are present in logs when the epoch ends.

You can see available keys in logs like this:

import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # `logs` holds the metrics Keras computed for this epoch
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))

model.fit(..., callbacks=[CustomCallback()])

If you run the code above, you will see something like this:

Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']

This shows the keys you can use (plus epoch), and lr is not among them, while your filepath uses three keys: epoch, lr, and val_loss.
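For example, a filepath built only from epoch and the keys that actually appear in logs formats without errors (a minimal sketch; the path and format specifiers are illustrative):

from tensorflow.keras.callbacks import ModelCheckpoint

# Uses only `epoch` and keys present in `logs`, so formatting succeeds
checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{loss:.3e}_vloss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')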


Solution:

You can add the learning rate to logs yourself:

import tensorflow as tf
import tensorflow.keras.backend as K

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Inject the current learning rate into `logs` so that other
        # callbacks (like ModelCheckpoint) can use it when formatting
        logs.update({'lr': K.eval(self.model.optimizer.lr)})
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))  # `lr` is now available

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[CustomCallback(), checkpoint])
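Note that CustomCallback is placed before checkpoint in the callbacks list: Keras invokes callbacks in list order, so lr has to be added to logs before ModelCheckpoint formats the filepath, otherwise you get the same KeyError.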

Upvotes: 1
