Aliyu Tetengi

Reputation: 19

Using best saved weights for the next epoch in tensorflow

I'm training a CNN model for image classification. The training accuracy improves consistently, but the validation accuracy fluctuates. I want to save the best weights and restore them whenever a bad epoch is encountered. I can save the best weights, but I don't know how to use them for the next epoch when a bad accuracy is encountered. How can I achieve this with TensorFlow?

Upvotes: 0

Views: 799

Answers (1)

Gerry P

Reputation: 8112

I have implemented this idea. It seemed to me that if your validation loss increases, you have moved to a point in N-space (N being the number of trainable parameters) that is less favorable than the point in N-space you were at in the previous epoch. Therefore it is best to reset the weights to those of the previous epoch. However, at that point the learning rate should also be reduced, or else on the next epoch you will end up at the same less favorable point. The same reasoning does not apply exactly when monitoring accuracy, but if that is what you want to do, the code below will do that for you. To use the callback, use the code

callbacks=[DWELL(model, monitor_acc, factor, verbose)]

where model is the name of your compiled model. monitor_acc is a boolean: if set to True, the training accuracy is monitored, and if the training accuracy has decreased on the current epoch, the model weights are set to those of the previous epoch and the learning rate is reduced. If monitor_acc is set to False, the same procedure is followed when the validation loss is higher on the current epoch than it was on the previous epoch. factor is a float between 0 and 1: when the monitored metric does not improve on the current epoch, the learning rate is set to new_lr = current_lr * factor. I typically set factor to .5. verbose is a boolean: if set to True, a printout occurs during training whenever the monitored metric does not improve, advising that the model weights have been set back to those of the previous epoch and showing the new, reduced learning rate; if set to False, no printout is produced. Below is an example of use:

callbacks=[DWELL(my_model, False, .5, True)]

Be sure to set callbacks=callbacks in model.fit
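To make the mechanics concrete before the Keras version, here is a minimal framework-free sketch of the same dwell logic. The names (dwell_train, train_step, evaluate) are hypothetical stand-ins for illustration, not Keras API:

```python
import copy

def dwell_train(weights, train_step, evaluate, epochs=10, lr=0.1, factor=0.5):
    """Sketch of the dwell idea: if the monitored loss worsens, restore the
    best-so-far weights and shrink the learning rate by `factor`."""
    best_weights = copy.deepcopy(weights)   # best point seen so far
    lowest_vloss = float('inf')             # best (lowest) loss seen so far
    for _ in range(epochs):
        weights = train_step(weights, lr)   # one epoch of training
        vloss = evaluate(weights)           # validation-style loss
        if vloss > lowest_vloss:
            weights = copy.deepcopy(best_weights)  # go back to the better point
            lr *= factor                           # take a smaller step next time
        else:
            lowest_vloss = vloss
            best_weights = copy.deepcopy(weights)  # record the new best point
    return weights, lr
```

On a toy quadratic loss with a deliberately oversized learning rate, the first bad epoch triggers a reset and a halved learning rate, after which the run converges, which is exactly the behavior the callback aims for.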

The code for the callback is shown below:

import numpy as np
import tensorflow as tf
from tensorflow import keras

class DWELL(keras.callbacks.Callback):
    def __init__(self, model, monitor_acc, factor, verbose):
        super(DWELL, self).__init__()
        self.model=model
        self.factor=factor # save the learning rate reduction factor
        self.initial_lr=float(tf.keras.backend.get_value(model.optimizer.lr)) # get the initial learning rate and save it
        self.lowest_vloss=np.inf # set lowest validation loss to infinity initially
        self.best_weights=self.model.get_weights() # set best weights to model's initial weights
        self.verbose=verbose
        self.monitor_acc=monitor_acc
        self.highest_acc=0
    def on_epoch_end(self, epoch, logs=None): # method runs at the end of each epoch
        lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) # get the current learning rate
        vloss=logs.get('val_loss') # get the validation loss for this epoch
        acc=logs.get('accuracy') # get the training accuracy for this epoch
        if not self.monitor_acc: # monitor validation loss
            if vloss>self.lowest_vloss:
                self.model.set_weights(self.best_weights) # reset weights to those of the best epoch
                new_lr=lr * self.factor
                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                if self.verbose:
                    print( '\n model weights reset to best weights and reduced lr to ', new_lr, flush=True)
            else:
                self.lowest_vloss=vloss
                self.best_weights=self.model.get_weights() # save the weights of this improved epoch
        else: # monitor training accuracy
            if acc<self.highest_acc:
                self.model.set_weights(self.best_weights) # reset weights to those of the best epoch
                new_lr=lr * self.factor
                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                if self.verbose:
                    print( '\n model weights reset to best weights and reduced lr to ', new_lr, flush=True)
            else:
                self.highest_acc=acc
                self.best_weights=self.model.get_weights() # save the weights of this improved epoch

Below is a sample printout produced during training, showing what happens when the validation loss for the current epoch exceeds that of the previous epoch:

Epoch 23/40
25/25 [==============================] - 3s 110ms/step - loss: 0.5927 - accuracy: 0.9825 - val_loss: 0.6827 - val_accuracy: 0.9000
Epoch 24/40
24/25 [===========================>..] - ETA: 0s - loss: 0.5812 - accuracy: 0.9869
 model weights reset to best weights and reduced lr to  0.0012499999720603228
25/25 [==============================] - 2s 86ms/step - loss: 0.5821 - accuracy: 0.9869 - val_loss: 0.6846 - val_accuracy: 0.9500
Epoch 25/40
25/25 [==============================] - 2s 86ms/step - loss: 0.5646 - accuracy: 0.9958 - val_loss: 0.6772 - val_accuracy: 0.9250

My advice is to ALWAYS monitor the validation loss; it is a better measure of model performance than training accuracy. A nice feature of the callback is that at the end of training, if you set monitor_acc=False, your model weights are always set to the weights of the epoch with the lowest validation loss.

Upvotes: 1
