Reputation: 453
I am kind of new to the TensorFlow world but have written some programs in Keras. Since TensorFlow 2 officially integrates Keras, I am quite confused about the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint. If anybody can shed light on this, I would appreciate it.
Upvotes: 3
Views: 3100
Reputation: 911
I also had a hard time differentiating between the checkpoint objects when I looked at other people's code, so I wrote down some notes about when to use which one and how to use them in general. Either way, I think they might be useful for other people having the same issue:
These are two ways of saving your model's checkpoints; each is for a different use case:
1) tf.train.Checkpoint and tf.train.CheckpointManager. These are useful when you are managing the training loop yourself.
You use them like this:
1.1) Checkpoint. Definition from the docs: "A Checkpoint object can be constructed to save either a single or group of trackable objects to a checkpoint file".
How to initialise it:
ckpt = tf.train.Checkpoint(discr_opt=discr_opt,
                           genrt_opt=genrt_opt,
                           wgan=wgan,
                           d_model=d_model,
                           g_model=g_model)
1.2) CheckpointManager
This manages the checkpoints you have defined: where to store them and things like how many to keep. Definition from the docs: "Manages multiple checkpoints by keeping some and deleting unneeded ones".
How to initialise it:
manager = tf.train.CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)
How to use it:
manager.save()
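Putting 1.1 and 1.2 together, here is a minimal sketch of the save side inside a hand-written training loop. train_step, dataset and epochs are hypothetical placeholders for your own code; only the CheckpointManager call is the actual API:

# Minimal sketch: periodic checkpointing in a custom training loop.
# train_step, dataset and epochs are hypothetical placeholders.
for epoch in range(epochs):
    for batch in dataset:
        train_step(batch)          # your own forward/backward pass
    save_path = manager.save()     # writes a checkpoint and prunes old ones (max_to_keep=3)
    print("Saved checkpoint for epoch", epoch, "at", save_path)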
2) tf.keras.callbacks.ModelCheckpoint. You want to use this callback when you are not managing the epoch iterations yourself, for example when you have set up a relatively simple Sequential model and you call model.fit(), which manages the training process for you.
Definition from the docs: "Callback to save the Keras model or model weights at some frequency."
How to initialise it:
Pass in the path where the model should be saved.
The option save_weights_only is False by default: the whole model is saved, not just the weights.
The option save_best_only is False by default: set it to True to keep only the best model according to the monitored metric (val_loss by default, so you need some validation data).
verbose is set to 0 by default: set it to 1 to get a message each time a checkpoint is saved.
mc = tf.keras.callbacks.ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)
How to use it:
model.fit(X, y, epochs=100, validation_split=0.1, callbacks=[mc])
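After training you can load back whatever the callback saved. A hedged sketch, assuming the full model was saved (save_weights_only=False) to the path used above; the exact on-disk format (SavedModel directory vs .h5 file) depends on the filepath and TF version:

# Reload the model written by the callback above.
restored_model = tf.keras.models.load_model("training_checkpoints/cp.ckpt")
# If save_weights_only=True had been used instead, rebuild the architecture
# first and then call model.load_weights("training_checkpoints/cp.ckpt").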
Upvotes: 1
Reputation: 579
It depends on whether a custom training loop is required. In most cases it's not, and you can just call model.fit() and pass tf.keras.callbacks.ModelCheckpoint. If you do need to write your custom training loop, then you have to use tf.train.Checkpoint (and tf.train.CheckpointManager), since there's no callback mechanism.
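In the custom-loop case, the usual counterpart to saving is restoring before training resumes. A small hedged sketch, reusing the ckpt and manager objects defined in the answer above:

# Resume from the newest checkpoint if one exists, otherwise start fresh.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from", manager.latest_checkpoint)
else:
    print("Initializing from scratch")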
Upvotes: 4
Reputation: 4251
TensorFlow is a 'computation' library and Keras is a deep learning library which can work with TF or other backends. So what TF provides is a more generic, not-so-customized-for-deep-learning version. If you just compare the docs you can see how much more comprehensive and customized ModelCheckpoint is. Checkpoint just reads and writes stuff from/to disk; ModelCheckpoint is much smarter!
Also, ModelCheckpoint is a callback. It means you can just make an instance of it and pass it to the fit function:
model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)
I took a quick look at Keras's implementation of ModelCheckpoint: it calls either the save or the save_weights method on Model, which in some cases uses TensorFlow's Checkpoint itself. So it is not a wrapper per se, but it certainly sits at a higher level of abstraction, more specialized for saving Keras models.
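To make the "it is just a callback around save/save_weights" point concrete, here is a stripped-down, hypothetical sketch (not the real ModelCheckpoint implementation, which also handles monitoring, best-only logic and save frequency):

import tensorflow as tf

# Stripped-down illustration of a checkpointing callback: Keras calls
# on_epoch_end during fit(), and the callback delegates to Model.save
# or Model.save_weights.
class NaiveCheckpoint(tf.keras.callbacks.Callback):
    def __init__(self, filepath, save_weights_only=False):
        super().__init__()
        self.filepath = filepath
        self.save_weights_only = save_weights_only

    def on_epoch_end(self, epoch, logs=None):
        if self.save_weights_only:
            self.model.save_weights(self.filepath)
        else:
            self.model.save(self.filepath)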
Upvotes: 1