I am using the `tf.estimator` API to predict punctuation. I trained the model on pre-processed data using TFRecords and `tf.train.shuffle_batch`. Now I want to make predictions. That works fine when I feed static NumPy data into `tf.constant` and return it from the `input_fn`.
However, I am working with sequence data: I need to feed one example at a time, and each input depends on the previous output. I also want to be able to process input that arrives through HTTP requests.
Every time `estimator.predict` is called, it reloads the checkpoint and rebuilds the entire graph, which is slow and expensive. So I need to be able to feed data to the `input_fn` dynamically.
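Why one long-lived prediction generator avoids the reload cost can be shown with a plain-Python simulation (no TensorFlow involved). In this sketch, `fake_predict`, `Feeder`, `per_call_loop` and `persistent_loop` are all hypothetical names: `fake_predict` only imitates the fact that a predict call does its expensive setup once and then yields one output per input it pulls from `input_fn`. Starting a new predictor per example pays the setup every step, while a single predictor whose input generator reads from a shared slot pays it once and still lets each input depend on the previous output.

```python
from typing import Callable, Dict, Iterator, List

stats: Dict[str, int] = {"loads": 0}  # counts simulated checkpoint loads


def fake_predict(input_fn: Callable[[], Iterator[int]]) -> Iterator[int]:
    """Hypothetical stand-in for estimator.predict (not the real API)."""
    stats["loads"] += 1  # simulated checkpoint restore + graph build
    for x in input_fn():
        yield x * 2      # simulated model output


def per_call_loop(steps: int) -> List[int]:
    """Start a new predictor for every example: setup runs every step."""
    outputs, x = [], 1
    for _ in range(steps):
        # next() runs immediately, so the lambda sees the current x.
        out = next(fake_predict(lambda: iter([x])))
        outputs.append(out)
        x = out + 1      # next input depends on the previous output
    return outputs


class Feeder:
    """Holds the next input; its generator reads the value lazily."""

    def __init__(self) -> None:
        self.next_input = 0

    def gen(self) -> Iterator[int]:
        while True:
            yield self.next_input


def persistent_loop(steps: int) -> List[int]:
    """Keep one predictor alive and feed it through a Feeder: setup runs once."""
    feeder = Feeder()
    predictor = fake_predict(feeder.gen)
    outputs, x = [], 1
    for _ in range(steps):
        feeder.next_input = x
        out = next(predictor)
        outputs.append(out)
        x = out + 1
    return outputs
```

The simulation assumes the input side never reads ahead: it requests exactly one input per output. A real input pipeline that prefetches or batches would try to pull the next input before the previous output exists, so that behaviour would need checking in any TensorFlow version of this idea.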
My current attempt is roughly this:
```python
import tensorflow as tf

# These ops are created in the default graph, not in the graph
# that estimator.predict builds when it calls input_fn.
feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN])
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]])
enqueue_op = q.enqueue(feature_input)

def input_fn():
    return q.dequeue()

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
predictor = estimator.predict(input_fn=input_fn)

sess = tf.Session()
x = None
output = None
while True:
    # The next input depends on the previous output.
    x = get_numpy_data(x, output)
    if x is None:
        break
    sess.run(enqueue_op, {feature_input: x})
    output = next(predictor)
    save_to_file(output)
sess.close()
```
However, I get the following error:

```
ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.
```

How can I asynchronously feed data into my existing graph through an `input_fn` to get predictions one at a time?
It turns out the main problem is that all tensors need to be created inside the `input_fn`; otherwise they don't get added to the same graph, because `Estimator.predict` builds a fresh graph and calls `input_fn` inside it. I needed to run an enqueue operation, but I had no way to access anything returned from the input function.

I ended up subclassing `Estimator` and writing a custom predict method that lets me add data to the prediction queue dynamically and return the results:
```python
# async_estimator.py
import six
import tensorflow as tf
# Private TF 1.x internals (tested on 1.2.1); these may change between versions.
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.estimator import _check_hooks_type
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.training import saver
from tensorflow.python.training import training


class AsyncEstimator(Estimator):

    def async_predictor(self,
                        dtype,
                        shape=None,
                        predict_keys=None,
                        hooks=None,
                        checkpoint_path=None):
        """Returns a tuple of functions: the first runs predictions on the model,
        the second cleans up.

        Args:
          dtype: the dtype of the input.
          shape: the shape of the input placeholder (optional).
          predict_keys: list of `str`, names of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
            then the rest of the predictions will be filtered from the dictionary.
            If `None`, returns all.
          hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
          checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
            latest checkpoint in `model_dir` is used.

        Returns:
          (predict, finish): tuple of functions
            predict: runs a single prediction and returns the results
              Args:
                x: NumPy array of input
              Returns:
                Evaluated value of the prediction
            finish: closes the session, allowing the program to exit

        Raises:
          ValueError: Could not find a trained model in model_dir.
          ValueError: if batch length of predictions are not same.
          ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = _check_hooks_type(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.format(
                self._model_dir))

        # Build the graph once: placeholder, queue and model all live in g.
        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.create_global_step(g)
            input_placeholder = tf.placeholder(dtype=dtype, shape=shape)
            # Capacity-1 queue: each enqueued example is consumed by the
            # next evaluation of `predictions`.
            queue = tf.FIFOQueue(1, dtype, shapes=shape)
            enqueue_op = queue.enqueue(input_placeholder)
            features = queue.dequeue()
            estimator_spec = self._call_model_fn(features, None,
                                                 model_fn_lib.ModeKeys.PREDICT)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            # The checkpoint is restored once, when the session is created.
            mon_sess = training.MonitoredSession(
                session_creator=training.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    scaffold=estimator_spec.scaffold,
                    config=self._session_config),
                hooks=hooks)

            def predict(x):
                if mon_sess.should_stop():
                    raise StopIteration
                # Push one example into the queue, then run the model on it.
                mon_sess.run(enqueue_op, {input_placeholder: x})
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    return preds_evaluated
                # Split the batched dict into one dict per example.
                preds = []
                for i in range(self._extract_batch_length(preds_evaluated)):
                    preds.append({
                        key: value[i]
                        for key, value in six.iteritems(preds_evaluated)
                    })
                return preds

            def finish():
                mon_sess.close()

            return predict, finish
```
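The dict branch of `predict` turns one dict of batched arrays into a list of per-example dicts. Here is a dependency-free sketch of that step, assuming plain Python lists instead of NumPy arrays. `split_batch` is a hypothetical helper written in the spirit of the private `_extract_batch_length` check (it raises `ValueError` when keys disagree on batch length, as the docstring describes); it does not call TensorFlow.

```python
from typing import Any, Dict, List, Sequence


def split_batch(preds: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Split {key: batch} into [{key: item_0}, {key: item_1}, ...]."""
    lengths = {key: len(value) for key, value in preds.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(
            'Batch length of predictions should be same: {}'.format(lengths))
    batch_length = next(iter(lengths.values()), 0)
    # One dict per example, taking the i-th entry of every key.
    return [{key: value[i] for key, value in preds.items()}
            for i in range(batch_length)]
```

For instance, `split_batch({'classes': [3], 'probabilities': [[0.1, 0.9]]})` gives `[{'classes': 3, 'probabilities': [0.1, 0.9]}]`. With the `(1, max_seq_length)` placeholder shape used below, and assuming every prediction tensor keeps that leading batch dimension, the dict case of `predict` likewise returns a one-element list.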
And here is the rough code to use it:
```python
import tensorflow as tf
from async_estimator import AsyncEstimator


def doPrediction(model_fn, model_dir, max_seq_length):
    estimator = AsyncEstimator(model_fn, model_dir=model_dir)
    predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))
    output = None
    try:
        while True:
            # my input is dependent on the previous output
            x = get_numpy_data(output)
            if x is None:
                break
            output = predict(x)
            save_to_disk(output)
    finally:
        # Always close the session, even if a prediction fails.
        finish()
```
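The structure of this API, a `predict` closure and a `finish` closure sharing one long-lived resource, can be sketched without TensorFlow. In the sketch below, `make_predictor` and `run_feedback_loop` are hypothetical names, and a `queue.Queue(maxsize=1)` stands in for the capacity-1 `FIFOQueue`: every input is pushed and then immediately consumed by the next model step, and calling `predict` after `finish` raises `StopIteration`, mirroring the `should_stop()` check in `AsyncEstimator`.

```python
import queue
from typing import Callable, Dict, List, Optional, Tuple


def make_predictor(
        model: Callable[[int], int]
) -> Tuple[Callable[[int], int], Callable[[], None]]:
    """Return (predict, finish) closures sharing one queue and one 'session'."""
    pending: "queue.Queue[int]" = queue.Queue(maxsize=1)  # like FIFOQueue(1, ...)
    state: Dict[str, bool] = {"closed": False}

    def predict(x: int) -> int:
        if state["closed"]:
            raise StopIteration
        # "enqueue_op": raises queue.Full if an input were already pending.
        pending.put_nowait(x)
        # "run(predictions)": dequeue the single pending input and run the model.
        return model(pending.get_nowait())

    def finish() -> None:
        state["closed"] = True  # stands in for closing the session

    return predict, finish


def run_feedback_loop(steps: int) -> List[int]:
    """Drive predict in a loop where each input is the previous output."""
    predict, finish = make_predictor(lambda x: x + 10)
    outputs: List[int] = []
    output: Optional[int] = None
    try:
        for _ in range(steps):
            x = 0 if output is None else output
            output = predict(x)
            outputs.append(output)
    finally:
        finish()
    return outputs
```

Keeping the queue at capacity 1 and always enqueueing right before evaluating means output i can only come from input i, which is what the feedback loop relies on.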
Note: this is a simple solution that works for my needs; it may need to be modified for other cases. It works on TensorFlow 1.2.1.
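One such case is the HTTP input mentioned in the question. If `predict` is called from a server that handles requests on several threads, two requests could interleave their enqueue and evaluation steps and receive each other's results. A possible modification, offered as an assumption and not tested against TensorFlow, is to serialize calls with a lock. `make_thread_safe` is a hypothetical helper, and `demo_concurrent_calls` uses a deliberately racy fake in place of the real `predict`.

```python
import threading
import time
from typing import Callable, Dict, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def make_thread_safe(predict: Callable[[T], R]) -> Callable[[T], R]:
    """Serialize calls so one request's enqueue + evaluation can't interleave with another's."""
    lock = threading.Lock()

    def locked_predict(x: T) -> R:
        with lock:
            return predict(x)

    return locked_predict


def demo_concurrent_calls(n: int) -> Dict[int, int]:
    """Call a racy fake predict (hypothetical) from n threads through the lock."""
    slot: Dict[str, int] = {}

    def racy_predict(x: int) -> int:
        slot["pending"] = x          # stand-in for the enqueue step
        time.sleep(0.01)             # gap in which another thread could overwrite it
        return slot["pending"] * 2   # stand-in for evaluating the predictions

    safe_predict = make_thread_safe(racy_predict)
    results: Dict[int, int] = {}

    def worker(i: int) -> None:
        results[i] = safe_predict(i)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

A lock trades throughput for correctness: requests are processed one at a time, which matches the one-example-at-a-time design of `async_predictor`.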
Hopefully TF will officially adopt something like this to make serving dynamic predictions with Estimator easier.