superduper

Reputation: 453

tf.keras.callbacks.ModelCheckpoint vs tf.train.Checkpoint

I am kind of new to the TensorFlow world but have written some programs in Keras. Since TensorFlow 2 officially adopted Keras as its high-level API (tf.keras), I am quite confused about the difference between tf.keras.callbacks.ModelCheckpoint and tf.train.Checkpoint. If anybody can shed light on this, I would appreciate it.

Upvotes: 3

Views: 3100

Answers (3)

Vasco

Reputation: 911

I also had a hard time telling apart the checkpoint objects I saw in other people's code, so I wrote down some notes about when to use which one and how to use them in general. Either way, I think they might be useful for other people with the same issue:

Saving model Checkpoints

There are two ways of saving your model's checkpoints, each for a different use case:

1) Checkpoint & CheckpointManager

This is useful when you are managing the training loop yourself.

You use them like this:

1.1) Checkpoint

Definition from the docs: "A Checkpoint object can be constructed to save either a single or group of trackable objects to a checkpoint file".

How to initialise it:

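  • You pass it keyword arguments (key=value pairs) for:
    • All the trackable objects that make up your model and that you want to keep track of (models, layers, optimizers, tf.Variables):
    • For example a generator, a discriminator and their optimizers. Plain Python functions, such as a loss function, are not trackable, so leave them out.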
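import tensorflow as tf

# discr_opt, genrt_opt, wgan, d_model and g_model are the optimizers and
# Keras models defined elsewhere in the training script
ckpt = tf.train.Checkpoint(discr_opt=discr_opt,
                           genrt_opt=genrt_opt,
                           wgan=wgan,
                           d_model=d_model,
                           g_model=g_model)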
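
A self-contained sketch of what a Checkpoint does on its own, assuming a toy Keras model in place of the GAN pieces above (g_model and step here are hypothetical stand-ins, and a temporary directory replaces a real checkpoint folder): save() writes the tracked state under a prefix and returns the path, and restore() loads it back into the same Python objects.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins: a Keras model and a plain tf.Variable are both
# trackable, so both can be passed to tf.train.Checkpoint.
g_model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
step = tf.Variable(0, dtype=tf.int64)

ckpt = tf.train.Checkpoint(g_model=g_model, step=step)

ckpt_dir = tempfile.mkdtemp()
step.assign(10)
saved_weights = [w.copy() for w in g_model.get_weights()]

# save() appends a counter to the prefix and returns the path it wrote.
save_path = ckpt.save(os.path.join(ckpt_dir, "ckpt"))

# Clobber the in-memory state...
step.assign(0)
g_model.set_weights([np.zeros_like(w) for w in saved_weights])

# ...and bring it back from disk into the same objects.
ckpt.restore(save_path)
print(save_path, int(step.numpy()))
```

Note that restore() matches values to objects by the keyword names used in the constructor, so the restoring Checkpoint must be built with the same keys (g_model, step) as the saving one.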
1.2) CheckpointManager

This manages the checkpoints you have defined: where they are stored and how many of them to keep. Definition from the docs: "Manages multiple checkpoints by keeping some and deleting unneeded ones"

How to initialise it:

  • Initialise it with the Checkpoint object you created as the first argument.
  • Then the directory in which to save the checkpoint files.
  • You probably also want to set how many checkpoints to keep (max_to_keep), since checkpoints of complex models can take up a lot of disk space.
manager = tf.train.CheckpointManager(ckpt, "training_checkpoints_wgan", max_to_keep=3)
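
To see what max_to_keep does in practice, here is a minimal sketch that saves five times into a temporary directory (standing in for "training_checkpoints_wgan"); to keep it small, the Checkpoint tracks only a hypothetical step variable.

```python
import tempfile

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)  # hypothetical tracked state
ckpt = tf.train.Checkpoint(step=step)

ckpt_dir = tempfile.mkdtemp()  # stands in for "training_checkpoints_wgan"
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

for _ in range(5):
    step.assign_add(1)
    manager.save()  # writes ckpt-1 ... ckpt-5, deleting the oldest beyond 3

# Only the three most recent checkpoints are still on disk.
print(manager.checkpoints)        # oldest first
print(manager.latest_checkpoint)  # path of the newest one
```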

How to use it:

  • We have set up the manager object with our specified checkpoints, so it's ready to use.
  • Call this at the end of each training epoch:
manager.save()
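
Putting the two pieces together in a hand-written training loop might look like the sketch below. It assumes a toy linear-regression model, an SGD optimizer and a hypothetical epoch_var counter in place of the WGAN objects, and a temporary directory in place of a real checkpoint folder. The restore() call at the top is what makes the loop resumable: on a fresh run latest_checkpoint is None and nothing is loaded.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Toy regression data and model; these stand in for the GAN pieces above.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3)).astype("float32")
y = X @ np.array([[1.0], [-2.0], [0.5]], dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
epoch_var = tf.Variable(0, dtype=tf.int64)  # hypothetical: records finished epochs

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, epoch=epoch_var)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=3)

# Resume from the newest checkpoint if one exists (no-op on a fresh run).
ckpt.restore(manager.latest_checkpoint)

for _ in range(int(epoch_var.numpy()), 5):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(X, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    epoch_var.assign_add(1)
    manager.save()  # end of epoch: write a checkpoint, prune old ones
```

With a fixed directory instead of the temporary one used here, a second run would restore the model weights, the optimizer state and epoch_var, and continue from where the previous run stopped.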

2) ModelCheckpoint (callback)

You want to use this callback when you are not managing the epoch iterations yourself. For example, when you have set up a relatively simple Sequential model and call model.fit(), which manages the training process for you.

Definition from the docs: "Callback to save the Keras model or model weights at some frequency."

How to initialise it:

  • Pass in the path where the model should be saved.

  • The option save_weights_only is set to False by default:

    • If you only want to save the weights, set it to True.
  • The option save_best_only is set to False by default:

    • If you only want to keep the best model instead of all of them, set it to True. "Best" is judged by the monitor argument (val_loss by default), so you also need validation data.
  • verbose defaults to 0; set it to 1 to print a message every time a checkpoint is saved.

from tensorflow.keras.callbacks import ModelCheckpoint
mc = ModelCheckpoint("training_checkpoints/cp.ckpt", save_best_only=True, save_weights_only=False)

How to use it:

  • The model checkpoint callback is now ready for training.
  • Pass the object in your callbacks list when you fit the model:
model.fit(X, y, epochs=100, validation_split=0.2, callbacks=[mc])
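
A runnable end-to-end version of the callback route, as a sketch under a few assumptions: toy random data, a one-layer classifier, weights-only saving, and a temporary directory instead of "training_checkpoints". validation_split is what supplies the val_loss that save_best_only compares, and the ".weights.h5" suffix is chosen because, as far as I know, Keras 3 requires it for weights-only files while older tf.keras treats any ".h5" path as HDF5.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy binary-classification data (assumption, for illustration only).
rng = np.random.default_rng(0)
X = rng.normal(size=(128, 3)).astype("float32")
y = (X.sum(axis=1, keepdims=True) > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

weights_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
mc = tf.keras.callbacks.ModelCheckpoint(
    weights_path,
    monitor="val_loss",      # the default, spelled out
    save_best_only=True,     # overwrite only when val_loss improves
    save_weights_only=True,
    verbose=1,
)

# validation_split holds out 20% of the data and reports val_loss each epoch.
history = model.fit(X, y, epochs=5, validation_split=0.2, callbacks=[mc], verbose=0)

model.load_weights(weights_path)  # restore the best epoch's weights
```

If you keep save_weights_only=False, recent Keras versions (Keras 3) expect the path to end in ".keras"; older tf.keras accepts a path like "training_checkpoints/cp.ckpt".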

Upvotes: 1

Jake Wu

Reputation: 579

It depends on whether a custom training loop is required. In most cases it isn't, and you can just call model.fit() and pass it a tf.keras.callbacks.ModelCheckpoint. If you do need to write your own training loop, then you have to use tf.train.Checkpoint (and tf.train.CheckpointManager), since there's no callback mechanism there.

Upvotes: 4

Mohammad Jafar Mashhadi

Reputation: 4251

TensorFlow is a 'computation' library and Keras is a deep learning library which can work with TF or PyTorch, etc. So what TF provides is a more generic version that is not specialized for deep learning. If you just compare the docs, you can see how much more comprehensive and specialized ModelCheckpoint is. Checkpoint just reads and writes tracked objects from/to disk. ModelCheckpoint is much smarter: it can monitor a metric, keep only the best model and control how often it saves.

Also, ModelCheckpoint is a callback, which means you can just create an instance of it and pass it to the fit function:

model_checkpoint = ModelCheckpoint(...)
model.fit(..., callbacks=[..., model_checkpoint, ...], ...)

I took a quick look at Keras's implementation of ModelCheckpoint: it calls either the save or the save_weights method of the Model, which in some cases uses TensorFlow's Checkpoint itself. So it is not a wrapper per se, but it certainly sits at a higher level of abstraction: it is more specialized for saving Keras models.
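
To make the "it just calls save_weights" point concrete, here is a hypothetical, heavily reduced callback. It is not Keras's actual ModelCheckpoint code (no monitoring, no best-only logic, no save_freq); it only shows that the core of such a callback is a save_weights() call hooked into on_epoch_end. The class name, file names and toy data are all assumptions for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MinimalCheckpoint(tf.keras.callbacks.Callback):
    """Hypothetical stand-in for ModelCheckpoint: saves weights every epoch."""

    def __init__(self, filepath):
        super().__init__()
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        # fit() sets self.model; "{epoch}" is filled in with a 1-based epoch.
        self.model.save_weights(self.filepath.format(epoch=epoch + 1))


out_dir = tempfile.mkdtemp()
model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

X = np.zeros((8, 2), dtype="float32")
y = np.zeros((8, 1), dtype="float32")
model.fit(X, y, epochs=2, verbose=0,
          callbacks=[MinimalCheckpoint(os.path.join(out_dir, "ep{epoch}.weights.h5"))])

print(sorted(os.listdir(out_dir)))
```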

Upvotes: 1
