rsky

Reputation: 41

How to monitor loss & val_loss at the same time to avoid overfitting a neural network to either the train set or the test set?

I've been taking part in a hackathon and playing with Keras callbacks and neural networks. Is there a way to monitor not only loss or val_loss but BOTH of them, to avoid overfitting to either the train or the test set? For example, can I put a function in the monitor field instead of just one field name?

I want to monitor val_loss and pick the lowest value, but I also want a second criterion: the minimum difference between val_loss and loss.

Upvotes: 2

Views: 1519

Answers (3)

ClaudiaR

Reputation: 3414

I have an answer to a pretty similar problem, here.

Basically, it is not possible to monitor multiple metrics with the built-in Keras callbacks. However, you can define a custom callback (see the documentation for more info) that accesses the logs at the end of each epoch and performs some operation on them.

For example, if you want to monitor both loss and val_loss, you can do something like this:

import tensorflow as tf
from tensorflow import keras

class CombineCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Log a combined metric that callbacks listed after this one can monitor.
        logs['combine_metric'] = logs['val_loss'] + logs['loss']

Side note: in my opinion, the most important thing to monitor is the validation loss. The training loss will of course keep dropping, so it is not really that meaningful to observe on its own. If you really want to monitor both, I suggest adding a multiplicative factor to give more weight to the validation loss. In that case:

class CombineCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        factor = 0.8  # weight given to the validation loss
        logs['combine_metric'] = factor * logs['val_loss'] + (1 - factor) * logs['loss']

Then, if you only want to monitor this new metric during the training, you can use it like this:

model.fit(
    ...
    callbacks=[CombineCallback()],
)

Instead, if you also want to stop the training based on the new metric, you should combine the new callback with an early stopping callback:

combined_cb = CombineCallback()
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="combine_metric", mode="min")
model.fit(
    ...
    callbacks=[combined_cb, early_stopping_cb],
)

Be sure to put the CombineCallback before the early stopping callback in the callbacks list, so that the new metric is already in the logs when EarlyStopping reads them.
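Putting the pieces together, here is a minimal end-to-end sketch. The synthetic dataset, the two-layer model, and the hyperparameters (factor, patience, epochs) are all placeholders chosen only to make the example self-contained:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

class CombineCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        factor = 0.8  # weight given to the validation loss
        logs["combine_metric"] = factor * logs["val_loss"] + (1 - factor) * logs["loss"]

# Tiny synthetic regression problem, just as a placeholder.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.layers.Dense(8, activation="relu"),
                          keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# mode="min" is passed explicitly: Keras cannot infer a direction
# for a metric name it does not know about.
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="combine_metric",
                                                  mode="min", patience=2)

history = model.fit(x, y, validation_split=0.25, epochs=3, verbose=0,
                    callbacks=[CombineCallback(), early_stopping_cb])
# history.history now also contains a "combine_metric" entry per epoch.
```

Because the History callback runs after the user-supplied callbacks, the new key also shows up in `history.history`, which is convenient for plotting it afterwards.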

Moreover, you can draw more inspiration here.

Upvotes: 2

Gerry P

Reputation: 8102

Below is a Keras custom callback that should do the job. The callback monitors both the training accuracy and the validation accuracy. The form of the callback is callbacks=[SOMT(model, train_thold, valid_thold)] where:

  • model is your compiled model
  • train_thold is a float. It is the training accuracy (as a fraction, e.g. .95 for 95%) that must be achieved by the model in order to conditionally stop training
  • valid_thold is a float. It is the validation accuracy (as a fraction) that must be achieved by the model in order to conditionally stop training. Note: to stop training, BOTH train_thold and valid_thold must be met or exceeded in the SAME epoch.
    If you want to stop training based solely on the training accuracy, set valid_thold to 0.0.
    Similarly, if you want to stop training on just the validation accuracy, set train_thold to 0.0.

Note that if both thresholds are not achieved in the same epoch, training continues until the number of epochs specified in model.fit is reached.
For example, to stop training when the training accuracy has reached or exceeded 95% and the validation accuracy has reached at least 85%, the code would be callbacks=[SOMT(my_model, .95, .85)].

import time
from tensorflow import keras

class SOMT(keras.callbacks.Callback):
    def __init__(self, model, train_thold, valid_thold):
        super(SOMT, self).__init__()
        self.model = model
        self.train_thold = train_thold
        self.valid_thold = valid_thold

    def on_train_begin(self, logs=None):
        print('Starting training - training will halt if training accuracy achieves or exceeds', self.train_thold)
        print('and validation accuracy meets or exceeds', self.valid_thold)
        msg = '{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format(
            'Epoch', 'Train Acc', 'Train Loss', 'Valid Acc', 'Valid Loss', 'Duration')
        print(msg)

    def on_train_batch_end(self, batch, logs=None):
        acc = logs.get('accuracy') * 100  # training accuracy as a percentage
        loss = logs.get('loss')
        msg = '{0:1s}processed batch {1:4s}  training accuracy= {2:8.3f}  loss: {3:8.5f}'.format(
            ' ', str(batch), acc, loss)
        print(msg, '\r', end='')  # print over the same line to show a running batch count

    def on_epoch_begin(self, epoch, logs=None):
        self.now = time.time()  # start timing the epoch

    def on_epoch_end(self, epoch, logs=None):
        later = time.time()
        duration = later - self.now
        tacc = logs.get('accuracy')
        vacc = logs.get('val_accuracy')
        tr_loss = logs.get('loss')
        v_loss = logs.get('val_loss')
        ep = epoch + 1
        print(f'{ep:^8.0f} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
        if tacc >= self.train_thold and vacc >= self.valid_thold:
            print(f'\ntraining accuracy and validation accuracy reached the thresholds on epoch {ep}')
            self.model.stop_training = True  # halt training

Upvotes: 1

DavidH

Reputation: 791

You can choose between two approaches:

  1. Create a custom metric to record the metric you want, by subclassing tf.keras.metrics.Metric. See https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Metric for an example.
    You can then use your metric in standard callbacks e.g. EarlyStopping()

  2. Create a custom callback to do the calculation (and take the action) you want, by subclassing tf.keras.callbacks.Callback. See https://www.tensorflow.org/guide/keras/custom_callback for how to do this.
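As an illustration of approach 1, here is a sketch of a custom metric; the class and metric names (MyMAE, my_mae) are made up for the example, and the implementation is just a hand-rolled mean absolute error. Once passed to model.compile(metrics=[MyMAE()]), it appears in the logs as my_mae / val_my_mae and can be monitored with e.g. EarlyStopping(monitor="val_my_mae", mode="min"):

```python
import tensorflow as tf

class MyMAE(tf.keras.metrics.Metric):
    """Illustrative custom metric: a hand-rolled mean absolute error."""

    def __init__(self, name="my_mae", **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of absolute errors and count of elements seen so far.
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        err = tf.abs(tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32))
        self.total.assign_add(tf.reduce_sum(err))
        self.count.assign_add(tf.cast(tf.size(err), tf.float32))

    def result(self):
        return self.total / self.count

    def reset_state(self):
        # Called by Keras at the start of each epoch.
        self.total.assign(0.0)
        self.count.assign(0.0)
```

Note that a metric only sees y_true and y_pred per batch, so it cannot combine train and validation losses the way the callback approaches above do; it is the right tool when your criterion is computable from predictions alone.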

Upvotes: 1
