Marcel

Reputation: 41

How to do early stopping with the evaluation loss using tf.estimator.train_and_evaluate?

I am using the TensorFlow Estimator API, specifically the method tf.estimator.train_and_evaluate(). There is an early-stopping hook for training, tf.contrib.estimator.stop_if_no_decrease_hook, but my issue is that the training loss is too jumpy to use for early stopping. Does anyone know how to do early stopping with tf.estimator based on the evaluation loss?

Upvotes: 4

Views: 680

Answers (1)

TF_Support

Reputation: 1836

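You can use tf.contrib.estimator.stop_if_no_decrease_hook for this. Although it is passed in as a training hook, it decides whether to stop by reading the metrics written by evaluation (from estimator.eval_dir() by default), so metric_name='loss' refers to the evaluation loss, not the training loss. For example: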

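import os

import tensorflow as tf  # TF 1.x, where tf.contrib is still available

# model_fn, model_dir, train_input_fn and eval_input_fn are assumed to be defined elsewhere.
estimator = tf.estimator.Estimator(model_fn, model_dir)

# The hook reads evaluation summaries from estimator.eval_dir(); create the directory
# up front in case the hook polls it before the first evaluation has written anything.
os.makedirs(estimator.eval_dir(), exist_ok=True)

# Stop training if the evaluation loss has not decreased for 1000 training steps,
# but never before global step 100.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))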

If tf.contrib is not available in your TensorFlow version (it is deprecated and was removed in TensorFlow 2.x), use tf.estimator.experimental.stop_if_no_decrease_hook instead; it takes the same arguments.

For example:

estimator = ...
# Hook to stop training if the evaluation loss does not decrease in over 100000 steps.
hook = tf.estimator.experimental.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

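The early-stopping hook uses the evaluation results to decide when to stop training, but max_steps_without_decrease is measured in training (global) steps, not in number of evaluations, so keep in mind how many evaluations will happen within that many steps. If you set the hook as hook = tf.estimator.experimental.stop_if_no_decrease_hook(estimator, "loss", 10000), training stops once 10k steps have passed since the best evaluation loss was recorded, so the decision depends only on the evaluations that fall within that 10k-step window. In local mode, train_and_evaluate runs an evaluation only after a new checkpoint has been saved (and no more often than EvalSpec's throttle_secs), so if checkpoints are written rarely, the window may contain only one or two evaluations.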
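To make the window arithmetic concrete, here is a simplified pure-Python sketch of the "no decrease" rule. It is not the library code: the function name should_stop_no_decrease and the eval_history dict (mapping a global step to the evaluation loss recorded there) are hypothetical stand-ins for what the real hook reads from the event files in the eval directory. With one evaluation every 2000 steps, a 10000-step window covers five evaluations after the best one.

```python
from typing import Dict, Optional


def should_stop_no_decrease(
    eval_history: Dict[int, float],
    max_steps_without_decrease: int,
    min_steps: int = 0,
) -> bool:
    """Simplified sketch of the no-decrease stopping rule.

    eval_history is a hypothetical input: global step -> evaluation loss.
    """
    best_loss: Optional[float] = None
    best_step = 0
    # Walk the evaluations in step order, ignoring those before min_steps.
    for step in sorted(eval_history):
        if step < min_steps:
            continue
        loss = eval_history[step]
        if best_loss is None or loss < best_loss:
            best_loss, best_step = loss, step
        # Stop once the best loss is at least max_steps_without_decrease steps old.
        if step - best_step >= max_steps_without_decrease:
            return True
    return False


# Hypothetical history: one evaluation every 2000 steps, best loss at step 4000.
history = {
    2000: 0.90, 4000: 0.70, 6000: 0.72, 8000: 0.71,
    10000: 0.75, 12000: 0.73, 14000: 0.74,
}

print(should_stop_no_decrease(history, max_steps_without_decrease=10000))  # True
print(should_stop_no_decrease(history, max_steps_without_decrease=20000))  # False
```

With a 20000-step window the same plateau is not yet enough to stop, which is why the window should be chosen together with how often checkpoints (and therefore evaluations) happen.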

You can read more in the documentation here: https://www.tensorflow.org/api_docs/python/tf/estimator/experimental/stop_if_no_decrease_hook. For all the early-stopping functions available (for example stop_if_no_increase_hook, stop_if_higher_hook and stop_if_lower_hook), see the source: https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/early_stopping.py
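Besides the "no decrease" rule, that module also has threshold-style hooks such as stop_if_lower_hook, which stop as soon as the metric drops below a fixed value instead of waiting for a plateau. The sketch below illustrates that rule in pure Python under the same simplifying assumptions as before; should_stop_if_lower and the eval_history dict are hypothetical names, not the library's internals.

```python
from typing import Dict


def should_stop_if_lower(
    eval_history: Dict[int, float],
    threshold: float,
    min_steps: int = 0,
) -> bool:
    """Simplified sketch of a threshold rule: stop as soon as an evaluation at
    or after min_steps reports a metric strictly below threshold."""
    for step in sorted(eval_history):
        if step >= min_steps and eval_history[step] < threshold:
            return True
    return False


# Hypothetical evaluation history: global step -> evaluation loss.
history = {2000: 0.90, 4000: 0.70, 6000: 0.72, 8000: 0.71}

print(should_stop_if_lower(history, threshold=0.75))  # True (0.70 at step 4000)
print(should_stop_if_lower(history, threshold=0.70))  # False (nothing strictly below)
```

A threshold rule is useful when you know what loss is "good enough"; the no-decrease rule is the better fit when, as in the question, you simply want to stop once the evaluation loss stops improving.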

Upvotes: 2
