Colin Skow

Reputation: 1006

How to run asynchronous predictions with TensorFlow Estimator API?

I am using the tf.estimator API to predict punctuation. I trained it on pre-processed data using TFRecords and tf.train.shuffle_batch. Now I want to make predictions. I can do this fine by feeding static NumPy data into tf.constant and returning it from the input_fn.

However, I am working with sequence data: I need to feed one example at a time, and each input depends on the previous output. I also want to be able to process input arriving through HTTP requests.

Every time estimator.predict is called, it reloads the checkpoint and recreates the entire graph, which is slow and expensive. So I need a way to feed data to the input_fn dynamically.
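For example, calling predict in a loop like this (an illustrative sketch; examples stands in for my real data source) restores the model from disk on every iteration:

for x in examples:
    # Each call builds a fresh graph and reloads the checkpoint before yielding.
    output = next(estimator.predict(input_fn=lambda: tf.constant(x)))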

My current attempt is roughly this:

import tensorflow as tf

# Feed one example at a time through a placeholder and a one-slot queue.
feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN])
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]])
enqueue_op = q.enqueue(feature_input)

def input_fn():
    return q.dequeue()

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
predictor = estimator.predict(input_fn=input_fn)
sess = tf.Session()
output = None

while True:
    # Each input depends on the previous prediction.
    x = get_numpy_data(output)
    if x is None:
        break
    sess.run(enqueue_op, {feature_input: x})
    output = predictor.next()
    save_to_file(output)

sess.close()

However, I am getting the following error:

ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.

How can I asynchronously plug data into my existing graph through an input_fn to get predictions one at a time?

Upvotes: 3

Views: 3337

Answers (1)

Colin Skow

Reputation: 1006

It turns out the main problem is that all tensors need to be created inside the input_fn, or they don't get added to the same graph. I needed to run an enqueue operation, but it was impossible to access anything returned by the input_fn.
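You can see the mismatch in a tiny standalone snippet (illustrative only): Estimator.predict builds a fresh tf.Graph internally and calls your input_fn inside it, so any tensor created at module level belongs to a different graph:

import tensorflow as tf

outside = tf.placeholder(tf.int32)    # lives in the default graph

with tf.Graph().as_default() as g:    # roughly what Estimator.predict does
    inside = tf.constant(0)
    print(inside.graph is g)          # True
    print(outside.graph is g)         # False -> "not from the passed-in graph"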

I ended up inheriting the Estimator class and creating a custom predict function which allows me to dynamically add data to the prediction queue and return the results:

# async_estimator.py

import six
import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.estimator import _check_hooks_type
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.training import saver
from tensorflow.python.training import training


class AsyncEstimator(Estimator):

    def async_predictor(self,
                dtype,
                shape=None,
                predict_keys=None,
                hooks=None,
                checkpoint_path=None):
        """Returns a tuple of functions: first runs predicitons on the model, second cleans up
        Args:
          dtype: the dtype of the input
          shape: the shape of the input placeholder (optional)
          predict_keys: list of `str`, name of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
            then rest of the predictions will be filtered from the dictionary. If
            `None`, returns all.
          hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
          checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
            latest checkpoint in `model_dir` is used.
        Returns:
          (predict, finish): tuple of functions

            predict: runs a single prediction and returns the results
                Args:
                    x: NumPy array of input
                Returns:
                    Evaluated value of the prediction

            finish: closes the session, allowing the program to exit

        Raises:
          ValueError: Could not find a trained model in model_dir.
          ValueError: if batch lengths of predictions are not the same.
          ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = _check_hooks_type(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.format(
                self._model_dir))

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.create_global_step(g)
            # Build the feed queue in the same graph as the model, so the
            # input tensors and the model_fn's layers share one graph.
            input_placeholder = tf.placeholder(dtype=dtype, shape=shape)
            queue = tf.FIFOQueue(1, dtype, shapes=shape)
            enqueue_op = queue.enqueue(input_placeholder)
            features = queue.dequeue()
            estimator_spec = self._call_model_fn(features, None,
                                                 model_fn_lib.ModeKeys.PREDICT)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            # A long-lived MonitoredSession restores the checkpoint once and
            # is reused for every subsequent predict() call.
            mon_sess = training.MonitoredSession(
                    session_creator=training.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=hooks)

            def predict(x):
                if mon_sess.should_stop():
                    raise StopIteration
                # Feed the example, then evaluate the prediction tensors.
                mon_sess.run(enqueue_op, {input_placeholder: x})
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    return preds_evaluated
                else:
                    preds = []
                    for i in range(self._extract_batch_length(preds_evaluated)):
                        preds.append({
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        })
                    return preds

            def finish():
                mon_sess.close()

            return predict, finish

And here is the rough code to use it:

import tensorflow as tf
from async_estimator import AsyncEstimator


def doPrediction(model_fn, model_dir, max_seq_length):
    estimator = AsyncEstimator(model_fn, model_dir=model_dir)
    predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))
    output = None

    while True:
        # my input is dependent on the previous output
        x = get_numpy_data(output)
        if x is None:
            break
        output = predict(x)
        save_to_disk(output)

    finish()
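Since the question also mentions HTTP: the returned predict closure can be called from a request handler, because the session stays open between calls. A rough Flask sketch (the route, payload format, and NumPy conversion are my own assumptions, not part of the solution above):

import numpy as np
import tensorflow as tf
from flask import Flask, jsonify, request
from async_estimator import AsyncEstimator

app = Flask(__name__)
# Set up once at startup; model_fn, model_dir and max_seq_length as above.
estimator = AsyncEstimator(model_fn, model_dir=model_dir)
predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))

@app.route('/predict', methods=['POST'])
def handle():
    x = np.array(request.get_json()['sequence'], dtype=np.int32).reshape(1, -1)
    return jsonify(prediction=np.asarray(predict(x)).tolist())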

Note: this is a simple solution that works for my needs; it may need to be modified for other cases. It works on TensorFlow 1.2.1.

Hopefully TF will officially adopt something like this to make serving dynamic predictions with Estimator easier.

Upvotes: 4
