Connor Hale

Reputation: 85

EarlyStopping not stopping model despite loading best weights

I am running an image classification program with tf.keras and am trying to get the error curves for accuracy and val_accuracy. However, when I add an EarlyStopping callback, the model does not stop training, even after the patience threshold has been exceeded.

I tried changing the EarlyStopping monitor value, but it is already monitoring the correct metric: execution reaches this branch in tensorflow/python/keras/callbacks.py:

    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        if self.restore_best_weights:
          if self.verbose > 0:
            print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)

My output shows that print line, so execution clearly reaches `self.model.stop_training = True`, yet training continues. Here is an example run that keeps going past the early-stopping point. At the end of epoch 9 it prints "Restoring model weights from the end of the best epoch.", but it then runs the 10th epoch anyway.

    Epoch 1/10
     9/10 [==========================>...] - ETA: 1s - loss: 1.1147 - categorical_accuracy: 0.6058
    Epoch 00001: val_categorical_accuracy improved from -inf to 0.25000, saving model to /home/chale/ml_classify/data/best.weights.hdf5
    10/10 [==============================] - 29s 3s/step - loss: 1.0876 - categorical_accuracy: 0.6013 - val_loss: 60.9186 - val_categorical_accuracy: 0.2500
    Epoch 2/10
     9/10 [==========================>...] - ETA: 0s - loss: 1.2638 - categorical_accuracy: 0.5694
    Epoch 00002: val_categorical_accuracy did not improve from 0.25000
    10/10 [==============================] - 7s 747ms/step - loss: 1.2278 - categorical_accuracy: 0.5750 - val_loss: 147.1493 - val_categorical_accuracy: 0.2396
    Epoch 3/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.5760 - categorical_accuracy: 0.8321
    Epoch 00003: val_categorical_accuracy improved from 0.25000 to 0.26042, saving model to /home/chale/ml_classify/data/best.weights.hdf5
    10/10 [==============================] - 10s 972ms/step - loss: 0.5569 - categorical_accuracy: 0.8288 - val_loss: 21.9862 - val_categorical_accuracy: 0.2604
    Epoch 4/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.4401 - categorical_accuracy: 0.8681
    Epoch 00004: val_categorical_accuracy improved from 0.26042 to 0.30208, saving model to /home/chale/ml_classify/data/best.weights.hdf5
    10/10 [==============================] - 9s 897ms/step - loss: 0.4383 - categorical_accuracy: 0.8687 - val_loss: 146.7307 - val_categorical_accuracy: 0.3021
    Epoch 5/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.4499 - categorical_accuracy: 0.8394
    Epoch 00005: val_categorical_accuracy did not improve from 0.30208
    10/10 [==============================] - 7s 714ms/step - loss: 0.4218 - categorical_accuracy: 0.8493 - val_loss: 71.2797 - val_categorical_accuracy: 0.1354
    Epoch 6/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.5760 - categorical_accuracy: 0.8194
    Epoch 00006: val_categorical_accuracy improved from 0.30208 to 0.38542, saving model to /home/chale/ml_classify/data/best.weights.hdf5
    10/10 [==============================] - 10s 974ms/step - loss: 0.5342 - categorical_accuracy: 0.8313 - val_loss: 13.7430 - val_categorical_accuracy: 0.3854
    Epoch 7/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.3852 - categorical_accuracy: 0.9000
    Epoch 00007: val_categorical_accuracy did not improve from 0.38542
    10/10 [==============================] - 6s 619ms/step - loss: 0.4190 - categorical_accuracy: 0.8973 - val_loss: 164.1882 - val_categorical_accuracy: 0.2708
    Epoch 8/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.3401 - categorical_accuracy: 0.8905
    Epoch 00008: val_categorical_accuracy did not improve from 0.38542
    10/10 [==============================] - 7s 723ms/step - loss: 0.3745 - categorical_accuracy: 0.8889 - val_loss: 315.0913 - val_categorical_accuracy: 0.2708
    Epoch 9/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.2713 - categorical_accuracy: 0.8958
    Epoch 00009: val_categorical_accuracy did not improve from 0.38542
    Restoring model weights from the end of the best epoch.
    10/10 [==============================] - 9s 853ms/step - loss: 0.2550 - categorical_accuracy: 0.9062 - val_loss: 340.6383 - val_categorical_accuracy: 0.2708
    Epoch 10/10
     9/10 [==========================>...] - ETA: 0s - loss: 0.4282 - categorical_accuracy: 0.8759
    Epoch 00010: val_categorical_accuracy did not improve from 0.38542
    Restoring model weights from the end of the best epoch.
    10/10 [==============================] - 8s 795ms/step - loss: 0.4260 - categorical_accuracy: 0.8758 - val_loss: 4.5791 - val_categorical_accuracy: 0.2500
    Epoch 00010: early stopping
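
For reference, here is a simplified replay of the patience counter from the callbacks.py excerpt, fed with the per-epoch val_categorical_accuracy values from this log (mode="max", min_delta=0, patience=3). It is a sketch of the logic, not the tf.keras source, and the helper name `replay_early_stopping` is made up. It agrees with the log that patience runs out after epoch 9, with epoch 6 as the best epoch.

```python
# Simplified replay of EarlyStopping's patience counter (mode="max", min_delta=0).
# This is a sketch of the logic, not the tf.keras implementation itself.
val_categorical_accuracy = [0.2500, 0.2396, 0.2604, 0.3021, 0.1354,
                            0.3854, 0.2708, 0.2708, 0.2708, 0.2500]


def replay_early_stopping(values, patience):
    """Return (stop_epoch, best_epoch) as 1-based epoch numbers.

    stop_epoch is None if patience is never exhausted.
    """
    best = float("-inf")   # EarlyStopping starts from -inf in "max" mode
    best_epoch = None
    wait = 0
    for epoch, current in enumerate(values, start=1):
        if current > best:
            # Improvement: remember it and reset the counter.
            best, best_epoch, wait = current, epoch, 0
        else:
            # No improvement: count towards patience.
            wait += 1
            if wait >= patience:
                # This is where EarlyStopping sets model.stop_training = True.
                return epoch, best_epoch
    return None, best_epoch


stop_epoch, best_epoch = replay_early_stopping(val_categorical_accuracy, patience=3)
print(f"stop after epoch {stop_epoch}, best epoch {best_epoch}")
# prints "stop after epoch 9, best epoch 6"
```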

Here is the relevant code:

        if loss == 'categorical_crossentropy':
            monitor = 'val_categorical_accuracy'
        else:
            monitor = 'val_binary_accuracy'

        early_stop = EarlyStopping(monitor=monitor, patience=3, verbose=1, restore_best_weights=True)

        checkpoint_path = '{}/best.weights.hdf5'.format(output_dir)
        best_model = ModelCheckpoint(checkpoint_path, monitor=monitor, verbose=1, save_best_only=True, mode='max')

        # reduce_lr = tensorflow.python.keras.callbacks.ReduceLROnPlateau()

        m = Metrics(labels=labels, val_data=validation_generator, batch_size=batch_size)
        history = model.fit_generator(train_generator,
                                      steps_per_epoch=steps_per_epoch,
                                      epochs=epochs,
                                      use_multiprocessing=True,
                                      validation_data=validation_generator,
                                      validation_steps=validation_steps,
                                      callbacks=[tensorboard, best_model, early_stop,
                                                 WandbCallback(data_type="image",
                                                               validation_data=validation_generator,
                                                               labels=labels)])  # , schedule])


        return history

The full code is on GitHub: https://github.com/AtlasHale/ml_classify

I expected that once the patience was exceeded, the remaining epochs would be skipped, and that the returned model would have the best weights restored. Instead, the model I get back is not the best model; it is the model from the last epoch. I would like to return the best model and skip the epochs that follow once early stopping triggers.

Edit: After copying the EarlyStopping class and adding some print statements, I got this output:

    Epoch 6/10
    8/9 [=========================>....] - ETA: 0s - loss: 0.5594 - categorical_accuracy: 0.9062
    Epoch 00006: val_categorical_accuracy did not improve from 0.27083
    3 epochs since improvement to val_categorical_accuracy
    Model stop_training state previously: False
    Model stop_training state now: True
    Restoring model weights from the end of the best epoch.
    9/9 [==============================] - 8s 855ms/step - loss: 0.5511 - categorical_accuracy: 0.8889 - val_loss: 466.1678 - val_categorical_accuracy: 0.2292
    Epoch 7/10
    8/9 [=========================>....] - ETA: 0s - loss: 0.3544 - categorical_accuracy: 0.8992
    Epoch 00007: val_categorical_accuracy did not improve from 0.27083
    4 epochs since improvement to val_categorical_accuracy
    Model stop_training state previously: False
    Model stop_training state now: True
    Restoring model weights from the end of the best epoch.

When self.model.stop_training is set to True, the value does not persist: by the time EarlyStopping runs at the end of the next epoch, the flag is False again. So it seems that whatever the callback sets is not being applied to the model, but I am not sure why. Any insight is welcome.
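
One way to picture this symptom: as far as I know, the tf.keras epoch loop runs every callback's on_epoch_end in list order and only checks model.stop_training once, after all of them have run. So if anything that runs after EarlyStopping sets the flag back to False, training continues. The toy loop below is a pure-Python sketch of that ordering; FakeModel, StopAfter, ResetStopFlag and toy_fit are hypothetical stand-ins, not Keras classes, and nothing here identifies which callback in this setup, if any, actually resets the flag.

```python
# Toy model of an epoch loop: callbacks run in list order, and the loop
# checks `model.stop_training` only once, after every callback has run.
# FakeModel, StopAfter, ResetStopFlag and toy_fit are hypothetical stand-ins.


class FakeModel:
    def __init__(self):
        self.stop_training = False


class StopAfter:
    """Stand-in for EarlyStopping: requests a stop from a given epoch on."""

    def __init__(self, epoch):
        self.epoch = epoch

    def on_epoch_end(self, model, epoch):
        if epoch >= self.epoch:
            model.stop_training = True


class ResetStopFlag:
    """Hypothetical callback that (directly or indirectly) clears the flag."""

    def on_epoch_end(self, model, epoch):
        model.stop_training = False


def toy_fit(model, epochs, callbacks):
    """Return the number of epochs actually run."""
    model.stop_training = False  # the flag is cleared when fitting starts
    ran = 0
    for epoch in range(1, epochs + 1):
        ran = epoch  # ... one epoch of training and validation would go here ...
        for cb in callbacks:
            cb.on_epoch_end(model, epoch)
        if model.stop_training:  # checked once, after all callbacks
            break
    return ran


print(toy_fit(FakeModel(), 10, [StopAfter(9)]))                   # 9
print(toy_fit(FakeModel(), 10, [StopAfter(9), ResetStopFlag()]))  # 10
```

With the resetting callback placed before the stopping one, the loop stops at epoch 9 again, which is why callback order matters in this picture.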

Upvotes: 2

Views: 1664

Answers (1)

Sasha Sokolov

Reputation: 1

I had the same problem: with epochs=100 and patience=5, training always ran for all 100 epochs.

I found that this tutorial gets the expected EarlyStopping behavior: https://lambdalabs.com/blog/tensorflow-2-0-tutorial-04-early-stopping/

The main hint: use the min_delta parameter. With it set, a change in the monitored value counts as an improvement only if it beats the previous best by more than min_delta, so training stops once there has been no such improvement for the number of epochs set by the patience parameter.
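
To illustrate the min_delta hint: in "max" mode (as for val_categorical_accuracy) a new value only resets the patience counter if `current - min_delta > best`, so a run of tiny gains no longer keeps training alive. The sketch below mimics that rule in plain Python; the helper names and the `slowly_improving` sequence are made up for illustration. In tf.keras the parameter itself is passed as `EarlyStopping(..., min_delta=...)`.

```python
# Sketch of how min_delta decides what counts as an improvement for a metric
# monitored in "max" mode. Helper names and data are illustrative only.


def is_improvement(current, best, min_delta=0.0):
    """True if `current` beats `best` by more than `min_delta`."""
    return current - abs(min_delta) > best


def epochs_until_stop(values, patience, min_delta=0.0):
    """Return the 1-based epoch at which patience runs out, or None."""
    best, wait = float("-inf"), 0
    for epoch, current in enumerate(values, start=1):
        if is_improvement(current, best, min_delta):
            best, wait = current, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up metric that creeps up by 0.001 per epoch.
slowly_improving = [0.800, 0.801, 0.802, 0.803, 0.804, 0.805, 0.806, 0.807]

print(epochs_until_stop(slowly_improving, patience=5))                  # None
print(epochs_until_stop(slowly_improving, patience=5, min_delta=0.01))  # 6
```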

Upvotes: 0
