Agrippa

Reputation: 485

Keras: Optimal epoch selection

I'm trying to write some logic that selects the best epoch when training a neural network in Keras. My code saves the training loss and the test loss for a set number of epochs and then picks the best-fitting epoch according to some logic. The code looks like this:

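from pandas import DataFrame, concat

ini_epochs = 100
# history is the object returned by model.fit(..., epochs=ini_epochs, validation_data=...)

df_train_loss = DataFrame(data=history.history['loss'], columns=['Train_loss'])
df_test_loss = DataFrame(data=history.history['val_loss'], columns=['Test_loss'])
df_loss = concat([df_train_loss, df_test_loss], axis=1)

# Start from the largest test loss so that any qualifying epoch replaces it
Min_loss = max(df_loss['Test_loss'])
best_epoch = 0
for i in range(len(df_loss)):
    Test_loss = df_loss['Test_loss'][i]
    Train_loss = df_loss['Train_loss'][i]
    # Keep the lowest test loss that is still above the training loss
    if Test_loss > Train_loss and Test_loss < Min_loss:
        Min_loss = Test_loss
        best_epoch = i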

The idea behind the logic is this: to get the best model, the selected epoch should be the one with the lowest test loss, but that test loss must still be above the training loss to avoid overfitting.
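The rule described above can be sketched in plain Python over the per-epoch lists stored in history.history. This is a minimal sketch rather than the asker's exact code: the function name select_epoch is hypothetical, and it returns None when no epoch satisfies the rule instead of silently falling back to epoch 0.

```python
def select_epoch(train_loss, test_loss):
    """Return the index of the epoch with the lowest test loss among epochs
    where the test loss is above the training loss, or None if none qualify."""
    # Epochs that satisfy the "test loss above training loss" condition
    candidates = [i for i, (tr, te) in enumerate(zip(train_loss, test_loss)) if te > tr]
    if not candidates:
        return None
    # Among those, pick the epoch with the smallest test loss
    return min(candidates, key=lambda i: test_loss[i])


# Hypothetical loss curves, for illustration only
train = [1.0, 0.8, 0.6, 0.5, 0.45]
test = [1.1, 0.9, 0.7, 0.65, 0.7]
print(select_epoch(train, test))  # → 3
```

With curves like the ones in the graph, where the test loss is below the training loss from the first epoch, select_epoch([1.0, 0.8], [0.9, 0.7]) returns None, which makes the failure case explicit instead of reporting epoch 0.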

In general, this epoch selection method works OK. However, if the test loss is below the training loss from the start, this method picks epoch zero (see the graph below).

[Image: training and test loss per epoch, with the test loss below the training loss from the start]

Now, I could add another if statement that checks whether the difference between the test and training losses is positive or negative, and then write logic for each case, but what happens if the difference starts positive and then becomes negative? I get confused and haven't been able to write working code.
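One way to handle curves that cross is to first locate the epochs where the sign of test - train changes, and then decide per segment. The helper below, crossing_epochs, is a hypothetical sketch in plain Python; it only reports strict sign changes, so an epoch where the two losses are exactly equal is not counted as a crossing.

```python
def crossing_epochs(train_loss, test_loss):
    """Return epochs i (i >= 1) where test - train has the opposite sign
    to the value at epoch i - 1, i.e. where the two loss curves cross."""
    diffs = [te - tr for tr, te in zip(train_loss, test_loss)]
    crossings = []
    for i in range(1, len(diffs)):
        # A negative product means the two differences have strictly opposite signs
        if diffs[i - 1] * diffs[i] < 0:
            crossings.append(i)
    return crossings


# Hypothetical curves: test starts above train and drops below it at epoch 2
print(crossing_epochs([1.0, 0.8, 0.6, 0.5], [1.1, 0.85, 0.55, 0.45]))  # → [2]
```

An empty list means the curves never cross, in which case the sign of the first difference tells you whether you are in the "test above train" case or the case shown in the graph.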

So, my questions are:

1) Can you show me what code you would write to account for the situation shown in the graph (and for the case where the test and training loss curves cross)? I'd say the strategy would be to take the epoch with the minimum difference.

2) There is a good chance that I'm going about this the wrong way. I know Keras has a callbacks feature, but I don't like the idea of using the save_best_only option because it can save overfitted models. Any advice on a better epoch selection method would be great.
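The "minimum difference" strategy suggested in question 1 can be written as an argmin over the absolute gap between the two curves. This is a minimal sketch assuming plain per-epoch lists such as history.history['loss'] and history.history['val_loss']; min_gap_epoch is a hypothetical name, and this is one possible policy rather than an established method.

```python
def min_gap_epoch(train_loss, test_loss):
    """Return the epoch index where |test_loss - train_loss| is smallest.
    On ties, the earliest such epoch is returned."""
    gaps = [abs(te - tr) for tr, te in zip(train_loss, test_loss)]
    return min(range(len(gaps)), key=gaps.__getitem__)


# Hypothetical curves where the test loss stays below the training loss
print(min_gap_epoch([1.0, 0.8, 0.6, 0.5], [0.7, 0.6, 0.55, 0.52]))  # → 3
```

One design caveat: the smallest gap is not necessarily the lowest test loss, so an early epoch where both curves are still high but close together can win under this criterion.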

Upvotes: 6

Views: 10797

Answers (2)

DINA TAKLIT

Reputation: 8388

Here is a simple example that illustrates how to use early stopping in Keras:

  • First, import the necessary callbacks:

    from keras.callbacks import EarlyStopping, ModelCheckpoint
    
  • Set up early stopping:

    # Set callback functions to early stop training and save the best model so far
    callbacks = [EarlyStopping(monitor='val_loss', patience=2),
                 ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]
    
  • Train the neural network:

    history = network.fit(train_features, # Features
                          train_target, # Target vector
                          epochs=20, # Number of epochs
                          callbacks=callbacks, # Early stopping and checkpointing
                          verbose=0, # Silent: no progress output during training
                          batch_size=100, # Number of observations per batch
                          validation_data=(test_features, test_target)) # Data for evaluation
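It can help to see which epoch ends up stored in best_model.h5. The function below is a simplified pure-Python model of the save_best_only rule with monitor='val_loss' (an assumption: it ignores the other ModelCheckpoint options and does not call Keras); checkpoint_epochs is a hypothetical helper. The last epoch in the returned list is the one whose model remains in the file.

```python
def checkpoint_epochs(val_loss):
    """Simplified model of ModelCheckpoint(monitor='val_loss', save_best_only=True):
    return the epochs at which the checkpoint file would be (over)written."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_loss):
        # Save only on a strict improvement over the best value seen so far
        if loss < best:
            best = loss
            saved.append(epoch)
    return saved


# Hypothetical validation losses, for illustration only
print(checkpoint_epochs([0.9, 0.7, 0.72, 0.65, 0.68]))  # → [0, 1, 3]
```

Epochs are 0-indexed here, whereas the Keras progress output counts from 1.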
    

See the full example here.

Please also check the question "Stop Keras Training when the network has fully converge", especially Daniel's top answer.

Upvotes: 3

Ramineni Ravi Teja

Reputation: 3906

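Use EarlyStopping, which is available in Keras. Early stopping ends training once the monitored quantity stops improving: with monitor='val_loss', training stops after the validation loss has failed to decrease for patience consecutive epochs (if you monitor validation accuracy instead, after it has failed to increase). Use ModelCheckpoint to save the model wherever you want.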
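The patience rule can be modelled in a few lines of plain Python. This is a simplified sketch of what EarlyStopping(monitor='val_loss', patience=...) does with its default min_delta=0 (it ignores baseline, restore_best_weights and other options, and does not call Keras); early_stop_epoch is a hypothetical helper.

```python
def early_stop_epoch(val_loss, patience):
    """Simplified model of EarlyStopping(monitor='val_loss', patience=patience).
    Return (last_epoch_run, best_epoch), both 0-indexed."""
    best = float("inf")
    best_epoch = 0
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_loss):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # training would stop after this epoch
    return len(val_loss) - 1, best_epoch  # all epochs ran without stopping


# Hypothetical validation losses with a late improvement at epoch 5
val_loss = [1.0, 0.8, 0.7, 0.75, 0.72, 0.6]
print(early_stop_epoch(val_loss, patience=2))  # → (4, 2)
print(early_stop_epoch(val_loss, patience=5))  # → (5, 5)
```

Note that EarlyStopping on its own leaves the model with the weights from the last epoch it ran, not the best one, which is why this answer pairs it with ModelCheckpoint and calls load_weights afterwards. Newer Keras versions also accept restore_best_weights=True on EarlyStopping for the same purpose.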

from keras.callbacks import EarlyStopping, ModelCheckpoint

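# rate_drop_lstm and rate_drop_dense are presumably dropout rates defined earlier in the script
STAMP = 'simple_lstm_glove_vectors_%.2f_%.2f' % (rate_drop_lstm, rate_drop_dense)
# Stop once val_loss has not improved for 5 consecutive epochs
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
bst_model_path = STAMP + '.h5'
# Save only the weights of the epoch with the lowest val_loss (the default monitor)
model_checkpoint = ModelCheckpoint(bst_model_path, save_best_only=True, save_weights_only=True)

hist = model.fit(data_train, labels_train,
                 validation_data=(data_val, labels_val),
                 epochs=50, batch_size=256, shuffle=True,
                 callbacks=[early_stopping, model_checkpoint])

# Restore the best weights, since training ends with the last epoch's weights
model.load_weights(bst_model_path)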

Refer to this link for more info.

Upvotes: 4
