
Reputation: 515

TF Keras ModelCheckpoint filepath batch number

I am using ModelCheckpoint (documented at https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint) to save a checkpoint every 500 batches within each epoch.

How do I set filepath so that it includes the batch number? I know I can use {epoch} and the keys in logs.

Upvotes: 3

Views: 1411

Answers (2)

Felipe Miranda

Reputation: 161

Assuming you are using tf.keras.callbacks.ModelCheckpoint with save_freq set to an integer (which is required to save after a certain number of batches), you can create a class that inherits from ModelCheckpoint and override its on_train_batch_end method:

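import tensorflow as tf

class CustomCallback(tf.keras.callbacks.ModelCheckpoint):
    def __init__(self, filepath, save_freq):
        super().__init__(filepath, save_freq=save_freq)
        self.model_name = filepath  # filename prefix, e.g. "checkpoint_"

    def on_train_batch_end(self, batch, logs=None):
        # _should_save_on_batch and _current_epoch are private ModelCheckpoint
        # members in TF 2.x and may change between versions.
        if self._should_save_on_batch(batch):
            # 0-based epoch and 1-based batch, matching the sample output below
            filename = self.model_name + "epoch_" + str(self._current_epoch) + "_batch_" + str(batch + 1) + ".tf"
            # Assumption: weights only; use self.model.save(filename) for the full model
            self.model.save_weights(filename)
            print("\nsaved checkpoint: " + filename + "\n")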

Then pass an instance of this class to model.fit through the callbacks argument:

SAVE_FREQ = 200 # number of batches 
custom_callback = CustomCallback(filepath="checkpoint_", save_freq=SAVE_FREQ)
model.fit(..., callbacks=[custom_callback])

This will add both the epoch and the batch number to the checkpoint filename.

Epoch 1/3
199/422 [=============>................] - ETA: 6s - loss: 0.0261 - accuracy: 0.9915
saved checkpoint: checkpoint_epoch_0_batch_200.tf

399/422 [===========================>..] - ETA: 0s - loss: 0.0263 - accuracy: 0.9914
saved checkpoint: checkpoint_epoch_0_batch_400.tf

422/422 [==============================] - 13s 31ms/step - loss: 0.0264 - accuracy: 0.9914 - val_loss: 0.0311 - val_accuracy: 0.9920
Epoch 2/3
177/422 [===========>..................] - ETA: 7s - loss: 0.0254 - accuracy: 0.9913
saved checkpoint: checkpoint_epoch_1_batch_178.tf

377/422 [=========================>....] - ETA: 1s - loss: 0.0252 - accuracy: 0.9912
saved checkpoint: checkpoint_epoch_1_batch_378.tf

422/422 [==============================] - 13s 32ms/step - loss: 0.0252 - accuracy: 0.9912 - val_loss: 0.0306 - val_accuracy: 0.9925
Epoch 3/3
156/422 [==========>...................] - ETA: 7s - loss: 0.0253 - accuracy: 0.9914
saved checkpoint: checkpoint_epoch_2_batch_156.tf

355/422 [========================>.....] - ETA: 2s - loss: 0.0246 - accuracy: 0.9919
saved checkpoint: checkpoint_epoch_2_batch_356.tf

422/422 [==============================] - 13s 31ms/step - loss: 0.0245 - accuracy: 0.9919 - val_loss: 0.0294 - val_accuracy: 0.9922
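
The batch numbers in this output drift from epoch to epoch (200 and 400, then 178 and 378, then 156 and 356). That pattern is consistent with the save_freq counter running across epoch boundaries instead of resetting at the start of each epoch. Here is a minimal plain-Python sketch of that arithmetic, assuming 422 steps per epoch as in the log; save_points is a hypothetical helper for illustration only, not part of Keras:

```python
def save_points(steps_per_epoch, save_freq, epochs):
    """Return (epoch, batch) pairs at which a save would happen.

    epoch is 0-based and batch is 1-based, as in the filenames above.
    Assumes the batch counter is never reset between epochs.
    """
    points = []
    total_batches = steps_per_epoch * epochs
    # 'seen' is the number of batches processed since training started.
    for seen in range(save_freq, total_batches + 1, save_freq):
        epoch, batch_index = divmod(seen - 1, steps_per_epoch)
        points.append((epoch, batch_index + 1))
    return points


for epoch, batch in save_points(steps_per_epoch=422, save_freq=200, epochs=3):
    print(f"checkpoint_epoch_{epoch}_batch_{batch}.tf")
```

If you need saves at the same batch numbers in every epoch, one option would be to pick a save_freq that divides the number of steps per epoch evenly.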

Upvotes: 6

Dhanushka Sandaruwan

Reputation: 340

The question is not entirely clear, but this might help. The Callback base class provides several hook methods (for example on_epoch_end and on_train_batch_end) that you can override to save at the points you need.

Sample code:

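from tensorflow.keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, N):
        super().__init__()
        self.N = N  # save every N epochs

    def on_epoch_end(self, epoch, logs=None):
        # epoch is the 0-based index passed in by Keras
        if epoch % self.N == 0:
            name = 'weights%04d.hdf5' % epoch
            self.model.save_weights(name)

callbacks_list = [WeightsSaver(10)]  # save weights every 10 epochs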
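
To get the batch number into the name with this approach, a callback could record the current epoch in on_epoch_begin and build the filename in on_train_batch_end. The naming part might look like the sketch below; checkpoint_name and its zero-padding widths are hypothetical choices for illustration, not something Keras provides:

```python
def checkpoint_name(prefix, epoch, batch, ext=".hdf5"):
    """Build a checkpoint filename from 1-based epoch and batch numbers."""
    # Zero-padding keeps filenames in training order when sorted alphabetically.
    return f"{prefix}_epoch{epoch:04d}_batch{batch:06d}{ext}"


print(checkpoint_name("weights", 3, 500))
```

Fixed-width numbers are a design choice here: without padding, batch 1000 would sort before batch 500 in a directory listing.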

Upvotes: -1
