Reputation: 3971
ModelCheckpoint can be used to save the best model based on a specific monitored metric, so it necessarily keeps track of the best metric value it has seen. If you train on Google Colab, for example, your instance can be killed without warning, and you would lose this information after a long training session.
I tried to pickle the ModelCheckpoint object, so that I can reuse the same object when I bring my notebook back, but got:
TypeError: can't pickle _thread.lock objects
Is there a good way to do this? You can try to reproduce it with:
import pickle
import tensorflow as tf

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# pickle writes bytes, so the file must be opened in binary mode ('wb')
with open('chkpt_cb.pickle', 'wb') as f:
    pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)
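The error itself can be reproduced without Keras. In the sketch below, `CallbackLike` is a hypothetical stand-in for the callback: any object that holds a `threading.Lock` somewhere among its attributes cannot be pickled, while a plain float such as a best-metric value pickles without trouble. Which attribute of the real callback ends up holding the lock is not something this sketch claims to show.

```python
import pickle
import threading


class CallbackLike:
    """Hypothetical stand-in for a callback that holds a lock internally."""

    def __init__(self):
        self.best = float("inf")       # the plain float we actually care about
        self._lock = threading.Lock()  # lock objects cannot be pickled


def try_pickle(obj):
    """Return True if obj can be pickled, False if pickling raises TypeError."""
    try:
        pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        return True
    except TypeError:
        return False


cb = CallbackLike()
print(try_pickle(cb))       # False: the lock makes the whole object unpicklable
print(try_pickle(cb.best))  # True: a float pickles fine
```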
Upvotes: 3
Views: 3753
Reputation: 3971
If the callback object can't be pickled (because of the thread lock, and pickling it isn't advisable anyway), you can pickle this instead:
best = chkpt_cb.best
This is the best value of the monitored metric that the callback has seen so far. It is a plain float, so you can pickle it, reload it next time, and then do this:
chkpt_cb.best = best  # chkpt_cb is the brand-new callback you create after Colab killed your session
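To see why restoring `best` matters, here is a simplified pure-Python model of the `save_best_only=True` decision for a metric that should decrease, such as `val_loss`. The function `epochs_saved` is hypothetical and only mirrors the idea (save when the new value is strictly better than `best`), not the exact Keras source. Recent tf.keras releases also accept an `initial_value_threshold` argument on `ModelCheckpoint` for seeding this value at construction; check that your version has it.

```python
def epochs_saved(val_losses, best=float("inf")):
    """Simplified model of save_best_only=True for a metric to minimize.

    Returns (epochs that would trigger a save, final best value).
    Epochs are 1-based indices within this (possibly resumed) run.
    """
    saved = []
    for epoch, current in enumerate(val_losses, start=1):
        if current < best:  # only strict improvements trigger a save
            saved.append(epoch)
            best = current
    return saved, best


# Losses after a restart: the re-compiled optimizer regresses at first.
losses = [0.45, 0.38, 0.29, 0.27]

print(epochs_saved(losses))             # fresh callback: ([1, 2, 3, 4], 0.27)
print(epochs_saved(losses, best=0.30))  # restored best:  ([3, 4], 0.27)
```

With the restored value, the two regressed epochs are skipped instead of overwriting the checkpoint directory with worse models.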
This is my own setup:
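import os
import pickle
import tensorflow as tf

# All paths should be on Google Drive; I omitted that here for simplicity.
chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# Restore the best value from a previous (killed) session, if there is one.
if os.path.exists('chkpt_cb.best.pickle'):
    with open('chkpt_cb.best.pickle', 'rb') as f:
        best = pickle.load(f)
    chkpt_cb.best = best

def save_chkpt_cb():
    # Persist the callback's current best value.
    with open('chkpt_cb.best.pickle', 'wb') as f:
        pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

# Runs after chkpt_cb in the callbacks list, so it sees the updated best value.
save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

# model, train_data_gen and dev_data_gen are defined elsewhere.
# fit_generator is deprecated in TF 2.x; model.fit accepts generators directly.
history = model.fit(train_data_gen,
                    validation_data=dev_data_gen,
                    epochs=5,
                    callbacks=[chkpt_cb, save_chkpt_cb_callback])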
So even if your Colab session gets killed, you can still retrieve the last best metric, pass it to your new callback instance, and continue training as usual. This helps especially when you re-compile a stateful optimizer: the fresh optimizer state may cause a temporary regression in the loss/metric, and you don't want to save those worse models during the first few epochs.
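Because the session can die at any moment, including while the pickle file is being written, a half-written file is possible. Below is a minimal sketch of a more defensive variant, assuming a local filesystem: write to a temporary file, then swap it into place with `os.replace`. It stores the value as JSON instead of pickle (my substitution; a single float needs nothing more), and the helper names `save_best` and `load_best` are hypothetical. Whether the rename is atomic on a Google Drive mount is not guaranteed.

```python
import json
import os
import tempfile


def save_best(path, best):
    """Write the best metric value to `path` as JSON via a temp file + rename."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temporary file in the same directory first...
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump({"best": float(best)}, f)  # float() also accepts NumPy floats
        f.flush()
        os.fsync(f.fileno())
    # ...then swap it into place; on POSIX filesystems this rename is atomic.
    os.replace(tmp_path, path)


def load_best(path, default=float("inf")):
    """Return the stored best value, or `default` if nothing was saved yet.

    The inf default matches a metric that should decrease, such as val_loss.
    """
    if not os.path.exists(path):
        return default
    with open(path) as f:
        return json.load(f)["best"]
```

In the setup above, this would replace the two pickle blocks with `chkpt_cb.best = load_best('chkpt_cb.best.json')` and `on_epoch_end=lambda epoch, logs: save_best('chkpt_cb.best.json', chkpt_cb.best)`.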
Upvotes: 5
Reputation: 4745
I think you might be misunderstanding the intended usage of the ModelCheckpoint object. It is a callback that gets called at particular points during training. The ModelCheckpoint callback in particular gets called after every epoch (if you keep the default period=1; newer TF versions replace period with save_freq='epoch') and saves your model to disk under the filename you pass as the filepath argument. The model is saved in the same way described here. Then, if you want to load that model later, you can do something like
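# Import from tf.keras to match the callback; mixing standalone keras and tf.keras can fail.
from tensorflow.keras.models import load_model

model = load_model('my_model.h5')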
Other answers on SO provide nice guidance and examples for continuing training from a saved model, for example: Loading a trained Keras model and continue training. Importantly, the saved H5 file stores everything about your model that is needed to continue training.
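If the best value was never saved separately, it can still be approximately recovered from the checkpoint filenames, since the pattern `model.{epoch:02d}-{val_loss:.4f}.h5` embeds both the epoch and `val_loss`. Here is a sketch with a hypothetical helper `best_checkpoint`, assuming the filenames follow that exact pattern; note that the recovered `val_loss` is rounded to four decimals.

```python
import re

# Matches names produced by the template 'model.{epoch:02d}-{val_loss:.4f}.h5'.
CKPT_RE = re.compile(r"^model\.(\d+)-(\d+\.\d+)\.h5$")


def best_checkpoint(filenames):
    """Pick the file with the lowest val_loss from an iterable of names.

    Returns (filename, epoch, val_loss), or None if nothing matches.
    Ties on val_loss go to the later epoch.
    """
    found = []
    for name in filenames:
        m = CKPT_RE.match(name)
        if m:
            found.append((float(m.group(2)), -int(m.group(1)), name))
    if not found:
        return None
    val_loss, neg_epoch, name = min(found)
    return name, -neg_epoch, val_loss


# Made-up filenames; in Colab you would pass os.listdir(checkpoint_dir).
names = ["model.01-0.4500.h5", "model.02-0.3100.h5",
         "model.03-0.2900.h5", "model.01-0.5000.h5", "notes.txt"]
print(best_checkpoint(names))  # ('model.03-0.2900.h5', 3, 0.29)
```

The returned name can be passed to `load_model`, the value assigned to `chkpt_cb.best`, and the epoch passed as `initial_epoch` to `fit` so that epoch numbering continues from there.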
As suggested in the Keras documentation, you should not use pickle to serialize your model. Simply register the ModelCheckpoint callback with your 'fit' function:
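import tensorflow as tf

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

# val_loss only exists if fit() runs validation, hence validation_split.
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          validation_split=0.2,
          callbacks=[chkpt_cb])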
Your model will be saved in an H5 file named after your pattern, with the epoch number and val_loss value formatted for you automatically. For example, the file saved after the 5th epoch with a val_loss of 0.0023 would be named model.05-0.0023.h5, and since you set save_best_only=True, a model is only saved when val_loss improves on the best value seen so far, so you don't pollute your directory with a bunch of unneeded model files.
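As far as I know, the callback fills in the filepath template with Python's `str.format`, passing the 1-based epoch number and the values from that epoch's logs. The resulting names can be checked in plain Python:

```python
# Same template as the callback's filepath; epoch is 1-based in the filename.
template = "model.{epoch:02d}-{val_loss:.4f}.h5"

print(template.format(epoch=5, val_loss=0.0023))  # model.05-0.0023.h5
print(template.format(epoch=12, val_loss=0.27))   # model.12-0.2700.h5
```

The `:.4f` spec always keeps the leading zero and pads to four decimals, which is also why any value recovered from a filename is rounded.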
Upvotes: 3