Reputation: 938
I have been using my own Estimator/Experiment like code for over a year, but I want to finally jump on the Dataset+Estimator bandwagon.
I would like to do something like the following:
for _ in range(N):
    estimator.train(train_input_fn, steps=1000)
    estimator.evaluate(validation_input_fn)
where train_input_fn creates a tf.data.Dataset that loops over the training set forever, and validation_input_fn creates a tf.data.Dataset that does one pass over the validation set.
Does train() maintain the state of train_input_fn across calls (i.e. only call it once if the reference matches)? Is this how people are doing their training loops with Estimator?
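For concreteness, this is roughly what I have in mind for the two input functions (the data tensors and batch size are just placeholders, and I assume tensorflow is imported as tf):
def train_input_fn():
    # Placeholder data; the important part is the infinite repeat().
    dataset = tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    return dataset.shuffle(10000).repeat().batch(32)

def validation_input_fn():
    # Single pass over the validation set: no repeat().
    dataset = tf.data.Dataset.from_tensor_slices((val_features, val_labels))
    return dataset.batch(32)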
Upvotes: 3
Views: 3258
Reputation: 2723
You can now also use the train_and_evaluate method from the Estimator API.
This is how it works:
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=self.model_dir,
    params=params
)

train_spec = tf.estimator.TrainSpec(input_fn, max_steps=N)

eval_spec = tf.estimator.EvalSpec(
    validation_input_fn,
    steps=None,            # evaluate until the (single-pass) validation dataset is exhausted
    start_delay_secs=120,  # start evaluating 120 seconds after the beginning of training
    throttle_secs=600      # evaluate at most once every 600 seconds
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Note that the number of steps between evaluations depends on computation time, not on global_step.
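Also note that train_and_evaluate only evaluates when a new checkpoint is available, so the Estimator's checkpoint interval effectively gates evaluation as well. A minimal sketch of aligning the two via RunConfig (the 600 seconds just mirrors throttle_secs above, and model_dir is a placeholder):
run_config = tf.estimator.RunConfig(save_checkpoints_secs=600)
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,  # placeholder for your model directory
    params=params,
    config=run_config,
)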
Upvotes: 2
Reputation: 938
As I mentioned in my comment above, it looks like it does not save state across calls to estimator.train().
A solution that I am going with, and possibly the intended method, is to pass evaluation listeners to estimator.train(). For example,
class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def __init__(self, estimator, input_fn):
        self.estimator = estimator
        self.input_fn = input_fn

    def after_save(self, session, global_step):
        # Run one evaluation pass each time a checkpoint is saved.
        self.estimator.evaluate(self.input_fn)

estimator.train(
    input_fn=lambda: _train_input_fn(...),
    max_steps=N,
    saving_listeners=[
        EvalCheckpointSaverListener(
            estimator,
            lambda: _eval_input_fn(...),
        ),
    ],
)
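Keep in mind that after_save only fires when a checkpoint is written, so the evaluation interval is determined by the checkpointing schedule rather than by train() itself. If you want the every-1000-steps behaviour from the question, one option (a sketch; the value and model_dir are just placeholders) is to set save_checkpoints_steps via RunConfig when constructing the estimator:
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,  # placeholder
    params=params,
    config=tf.estimator.RunConfig(save_checkpoints_steps=1000),
)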
Upvotes: 5