Mark Woodward

Reputation: 938

Does tf.estimator.Estimator.train() maintain input_fn state?

I have been using my own Estimator/Experiment-like code for over a year, but I want to finally jump on the Dataset+Estimator bandwagon.

I would like to do something like the following:

for _ in range(N):
  estimator.train(train_input_fn, steps=1000)
  estimator.evaluate(validation_input_fn)

Where train_input_fn creates a tf.data.Dataset that loops over the training set forever, and validation_input_fn creates a tf.data.Dataset that does one pass of the validation set.

Does train() maintain the state of train_input_fn across calls (i.e. call it only once if the same function object is passed)? Is this how people structure their training loops with Estimator?

Upvotes: 3

Views: 3258

Answers (2)

syltruong

Reputation: 2723

You can now also use the train_and_evaluate method from the Estimator API.

This is how it works:

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=self.model_dir,
    params=params
)

train_spec = tf.estimator.TrainSpec(input_fn, max_steps=N)

eval_spec = tf.estimator.EvalSpec(
    validation_input_fn,
    steps=None,  # evaluate on the full validation set (one pass)
    start_delay_secs=120,  # start evaluating 120 seconds after training begins
    throttle_secs=600  # re-evaluate at most every 600 seconds
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

Note that the interval between evaluations is governed by wall-clock time (throttle_secs), not by global_step: an evaluation runs only when a new checkpoint is available and at least throttle_secs have elapsed since the previous one.

Upvotes: 2

Mark Woodward

Reputation: 938

As I mentioned in my comment above, it looks like estimator.train() does not preserve input_fn state across calls.
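
To see why, here is a toy, pure-Python sketch (no TensorFlow; ToyEstimator and make_input_fn are made-up names for illustration) of the relevant behavior: each call to train() invokes the input_fn again, so any iterator the input_fn builds is recreated from scratch rather than resumed.

```python
# Toy illustration of Estimator.train() re-invoking input_fn on every call.
def make_input_fn(data):
    def input_fn():
        return iter(data)  # a fresh iterator is built on each invocation
    return input_fn

class ToyEstimator:
    """Stand-in for tf.estimator.Estimator; only mimics input_fn handling."""
    def __init__(self):
        self.seen = []

    def train(self, input_fn, steps):
        iterator = input_fn()  # called anew on every train() call
        for _ in range(steps):
            self.seen.append(next(iterator))

est = ToyEstimator()
train_input_fn = make_input_fn([1, 2, 3, 4])
est.train(train_input_fn, steps=2)
est.train(train_input_fn, steps=2)
print(est.seen)  # [1, 2, 1, 2] -- iteration restarted, not [1, 2, 3, 4]
```

Passing the same function object twice does not help: it is the return value of input_fn (the dataset/iterator), not the function itself, that holds the iteration state.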

A solution that I am going with, and possibly the intended method, is to pass evaluation listeners to estimator.train(). For example,

import tensorflow as tf

class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
  """Runs an evaluation every time a checkpoint is saved during training."""
  def __init__(self, estimator, input_fn):
    self.estimator = estimator
    self.input_fn = input_fn

  def after_save(self, session, global_step):
    self.estimator.evaluate(self.input_fn)

estimator.train(
  input_fn=lambda: _train_input_fn(...),
  max_steps=N,
  saving_listeners=[
    EvalCheckpointSaverListener(
      estimator,
      lambda: _eval_input_fn(...),
    ),
  ],
)

Upvotes: 5
