Reputation: 19
I'm training a CNN model for image classification. The training accuracy is improving consistently, but the validation accuracy fluctuates. I want to save the best weights and restore them whenever an epoch produces a bad accuracy. I can save the best weights, but I don't know how to load them for the next epoch when a bad accuracy is encountered. How can I achieve this with TensorFlow?
Upvotes: 0
Views: 799
Reputation: 8112
I have implemented this idea. It seemed to me that if your validation loss increased, you have moved to a point in N-space (N being the number of trainable parameters) that is less favorable than the point you were at in the previous epoch. Therefore it is best to reset the weights to those of the previous epoch. However, at that point the learning rate should also be reduced, or else on the next epoch you will end up at the same less favorable point. The same reasoning does NOT hold exactly for monitoring training accuracy, but if that is what you want to do, the code below will do it for you. To use the callback, use the code
callbacks=[DWELL(model, monitor_acc, factor, verbose)]
where model is the name of your compiled model. monitor_acc is a boolean. If set to True, the training accuracy is monitored: if the training accuracy has decreased on the current epoch, the model weights are set back to those of the previous epoch and the learning rate is reduced. If monitor_acc is set to False, the same procedure is followed when the validation loss is higher on the current epoch than it was on the previous epoch. factor is a float between 0 and 1. When the monitored metric does not improve on the current epoch, the learning rate is set as new_lr = current_lr * factor. I typically set factor to .5. verbose is a boolean. If set to True, a printout occurs during training whenever the monitored metric fails to improve; it advises that the model weights have been set back to those of the previous epoch and prints the new, reduced learning rate value. If verbose is set to False, no printout is produced. Below is an example of use:
callbacks=[DWELL(my_model, False, .5, True)]
Be sure to set callbacks=callbacks in model.fit
The code for the callback is shown below:
import numpy as np
import tensorflow as tf
from tensorflow import keras

class DWELL(keras.callbacks.Callback):
    def __init__(self, model, monitor_acc, factor, verbose):
        super(DWELL, self).__init__()
        self.model = model
        self.factor = factor  # learning rate reduction factor
        self.initial_lr = float(tf.keras.backend.get_value(model.optimizer.lr))  # get the initial learning rate and save it
        self.lowest_vloss = np.inf  # set lowest validation loss to infinity initially
        self.best_weights = self.model.get_weights()  # set best weights to model's initial weights
        self.verbose = verbose
        self.monitor_acc = monitor_acc
        self.highest_acc = 0

    def on_epoch_end(self, epoch, logs=None):  # method runs at the end of each epoch
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))  # get the current learning rate
        vloss = logs.get('val_loss')  # get the validation loss for this epoch
        acc = logs.get('accuracy')  # get the training accuracy for this epoch
        if self.monitor_acc == False:  # monitor validation loss
            if vloss > self.lowest_vloss:
                self.model.set_weights(self.best_weights)  # reset to the best weights seen so far
                new_lr = lr * self.factor
                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                if self.verbose:
                    print('\n model weights reset to best weights and reduced lr to ', new_lr, flush=True)
            else:
                self.lowest_vloss = vloss
                self.best_weights = self.model.get_weights()  # this epoch improved, so save its weights
        else:  # monitor training accuracy
            if acc < self.highest_acc:
                self.model.set_weights(self.best_weights)  # reset to the best weights seen so far
                new_lr = lr * self.factor
                tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)
                if self.verbose:
                    print('\n model weights reset to best weights and reduced lr to ', new_lr, flush=True)
            else:
                self.highest_acc = acc
                self.best_weights = self.model.get_weights()  # this epoch improved, so save its weights
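To see the reset-and-reduce decision in isolation, here is a minimal pure-Python sketch of the same logic (no TensorFlow required; the per-epoch loss values and weights are made up for illustration):

```python
def dwell_step(curr_loss, best_loss, weights, best_weights, lr, factor=0.5):
    """One epoch-end decision: keep the new weights, or reset and reduce lr."""
    if curr_loss > best_loss:  # this epoch was worse: reset weights, reduce lr
        return best_weights, best_loss, best_weights, lr * factor
    # this epoch improved: its weights become the new best
    return weights, curr_loss, weights, lr

# simulate four epochs with hypothetical losses and weight updates
weights, best_weights = [0.0], [0.0]
best_loss, lr = float('inf'), 0.01
for loss, new_w in [(0.9, [1.0]), (0.7, [2.0]), (0.8, [3.0]), (0.6, [4.0])]:
    weights = new_w  # pretend training produced these weights
    weights, best_loss, best_weights, lr = dwell_step(
        loss, best_loss, weights, best_weights, lr)
    print(weights, round(lr, 4))
# epoch 3 (loss 0.8 > 0.7) resets weights to [2.0] and halves lr to 0.005
```

The key point is that the reset and the learning rate reduction happen together; resetting alone would simply replay the same bad step on the next epoch.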
Below is a sample printout produced during training, showing what happens when the validation loss for the current epoch exceeds that of the previous epoch:
Epoch 23/40
25/25 [==============================] - 3s 110ms/step - loss: 0.5927 - accuracy: 0.9825 - val_loss: 0.6827 - val_accuracy: 0.9000
Epoch 24/40
24/25 [===========================>..] - ETA: 0s - loss: 0.5812 - accuracy: 0.9869
model weights reset to best weights and reduced lr to 0.0012499999720603228
25/25 [==============================] - 2s 86ms/step - loss: 0.5821 - accuracy: 0.9869 - val_loss: 0.6846 - val_accuracy: 0.9500
Epoch 25/40
25/25 [==============================] - 2s 86ms/step - loss: 0.5646 - accuracy: 0.9958 - val_loss: 0.6772 - val_accuracy: 0.9250
My advice is to ALWAYS monitor the validation loss; it is a better measure of model performance than training accuracy. A nice feature of the callback is that with monitor_acc=False, at the end of training your model weights are always set to the weights of the epoch with the lowest validation loss.
Upvotes: 1