Arshad

Reputation: 361

Making use of a previously saved model checkpoint when training on new data

I was experimenting with model checkpoints using the code below:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import Callback, TensorBoard, ModelCheckpoint

def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer = "adam", loss = "mean_squared_error")
    return model

# Creating the model
model = get_model()

# Checkpoint of model weights
weights_filename = "weights/model_best_weights"
checkpoint = ModelCheckpoint(weights_filename, monitor = 'loss', verbose = 1, save_best_only = True, save_weights_only = True, mode = 'auto', period = 1)


# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target, verbose = 2, callbacks=[checkpoint])

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

After training the model, I get the following output:

WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
4/4 - 0s - loss: 4.5422

Epoch 00001: loss improved from inf to 4.54224, saving model to weights/model_best_weights
INFO:tensorflow:Assets written to: my_model/assets
Next, I reload the saved model and the best weights:

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")
weights_filename = "weights/model_best_weights"
reconstructed_model.load_weights(weights_filename)  # Loading previous saved weights

# This is the part in question: instantiating a new checkpoint
checkpoint = ModelCheckpoint(weights_filename, monitor = 'loss', verbose = 1, save_best_only = True, save_weights_only = True, mode = 'auto', period = 1)

Using the reloaded model to train on new data:

# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))

reconstructed_model.fit(test_input, test_target, verbose = 2, callbacks = [checkpoint])

The loss comes out as:

4/4 - 0s - loss: 3.8699
 
Epoch 00001: loss improved from inf to 3.86991, saving model to
weights/model_best_weights

The issue here is that I have to instantiate ModelCheckpoint again in order to keep saving the best weights.

Shouldn't the best loss tracked by the checkpoint start from the last saved value, rather than being reset to inf?

Upvotes: 0

Views: 787

Answers (1)

Innat

Reputation: 17219

Okay, the thing is, when you create an instance of the ModelCheckpoint callback, it initializes the best monitored value to +inf or -inf (depending on the mode). So when you create the instance a second time, that initial value is set again, which is why your second run also logs "improved from inf". You can see this in the ModelCheckpoint source code, excerpted below:

...
    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.Inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.Inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.Inf
      else:
        self.monitor_op = np.less
        self.best = np.Inf
...
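If you want the second run to keep comparing against the best loss from the first run, here is a minimal sketch of one workaround (my own illustration, reusing the variables from the question; previous_best is a value you would record yourself, here taken from the first run's log): set the callback's best attribute after constructing it, before resuming training.

# Sketch: carry the previous run's best loss into the new callback so that
# save_best_only compares against it instead of the fresh np.Inf.
previous_best = 4.54224  # best loss reported in the first run's checkpoint log

checkpoint = ModelCheckpoint(weights_filename, monitor='loss', verbose=1,
                             save_best_only=True, save_weights_only=True,
                             mode='auto')
checkpoint.best = previous_best  # overrides the np.Inf set in __init__ above

reconstructed_model.fit(test_input, test_target, verbose=2, callbacks=[checkpoint])

Newer tf.keras versions also expose an initial_value_threshold argument on ModelCheckpoint that serves the same purpose, if your version supports it.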

Upvotes: 1
