SanityIO

Reputation: 824

TensorFlow input pipeline

I have an input pipeline where samples are generated on the fly. I use Keras with a custom ImageDataGenerator and a corresponding Iterator to get samples into memory. Under the assumption that Keras in my setup is using feed_dict (and whether that assumption holds is itself one of my questions), I am thinking of speeding things up by switching to raw TensorFlow + Dataset.from_generator().

Here I see that the suggested solution for input pipelines that generate data on the fly in the most recent TensorFlow is to use Dataset.from_generator().

Questions:

  1. Does Keras with the TensorFlow backend use the feed_dict method?
  2. If I switch to raw TensorFlow + Dataset.from_generator(my_sample_generator), will that cut the feed_dict memory-copy overhead and buy me performance?
  3. During the predict (evaluation) phase, apart from batch_x and batch_y, my generator also outputs an opaque index vector that corresponds to the sample ids in batch_x. Does that mean I'm stuck with the feed_dict approach for the predict phase, because I need that extra batch_z output from the iterator?

Upvotes: 2

Views: 932

Answers (1)

mrry

Reputation: 126194

The new tf.contrib.data.Dataset.from_generator() can potentially speed up your input pipeline by overlapping the data preparation with training. However, you will tend to get the best performance by switching over to TensorFlow ops in your input pipeline wherever possible.
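For illustration, here is a minimal sketch of that pattern. It assumes the tf.data API as of TF 1.4+ (where Dataset graduated out of tf.contrib.data), and the generator, shapes, and batch size are placeholders standing in for your own:

    import numpy as np
    import tensorflow as tf

    # Placeholder generator standing in for a custom ImageDataGenerator:
    # yields one (image, label) sample at a time.
    def my_sample_generator():
        while True:
            yield np.random.rand(64, 64, 3).astype(np.float32), np.int64(0)

    dataset = tf.data.Dataset.from_generator(
        my_sample_generator,
        output_types=(tf.float32, tf.int64),
        output_shapes=(tf.TensorShape([64, 64, 3]), tf.TensorShape([])))

    dataset = dataset.batch(32)
    # prefetch(1) lets the pipeline prepare the next batch in the background
    # while the current batch is being consumed by the training step.
    dataset = dataset.prefetch(1)

    iterator = dataset.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()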

To answer your specific questions:

  1. The Keras TensorFlow backend uses tf.placeholder() to represent compiled function inputs, and feed_dict to pass arguments to a function.

  2. Given the recent optimizations that reduced tf.py_func() and feed_dict copy overhead, I suspect the amount of time spent in memcpy() will be the same. However, you can more easily use Dataset.from_generator() with Dataset.prefetch() (as in the sketch above) to overlap the training on one batch with preprocessing on the next batch.

  3. It sounds like you can define a separate iterator for the prediction phase. The tf.estimator.Estimator class does something similar by instantiating different "input functions" with different signatures for training and evaluation, then building a separate graph for each role.

    Alternatively, you could add a dummy output to your training iterator (for the batch_z values) and switch between training and evaluation iterators using a "feedable iterator" (see the sketch below).
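Here is a minimal, self-contained sketch of that feedable-iterator pattern, again assuming the TF 1.4+ tf.data API; the toy generator, shapes, and -1 dummy id are placeholders for your real data:

    import numpy as np
    import tensorflow as tf

    # Two datasets with the same structure: the training one carries a
    # dummy batch_z so it matches the prediction dataset's (x, y, z) output.
    def make_dataset(with_real_ids):
        def gen():
            for i in range(100):
                x = np.random.rand(4).astype(np.float32)
                y = np.int64(i % 2)
                z = np.int64(i) if with_real_ids else np.int64(-1)  # sample id or dummy
                yield x, y, z
        return tf.data.Dataset.from_generator(
            gen, (tf.float32, tf.int64, tf.int64),
            (tf.TensorShape([4]), tf.TensorShape([]), tf.TensorShape([]))
        ).batch(10)

    train_dataset = make_dataset(with_real_ids=False)
    predict_dataset = make_dataset(with_real_ids=True)

    # A string-valued placeholder selects which iterator feeds get_next().
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    batch_x, batch_y, batch_z = iterator.get_next()

    train_iterator = train_dataset.make_one_shot_iterator()
    predict_iterator = predict_dataset.make_initializable_iterator()

    with tf.Session() as sess:
        train_handle = sess.run(train_iterator.string_handle())
        predict_handle = sess.run(predict_iterator.string_handle())

        # Training step: pull batches from the training iterator.
        sess.run([batch_x, batch_y], feed_dict={handle: train_handle})

        # Prediction: swap the handle; batch_z now carries real sample ids.
        sess.run(predict_iterator.initializer)
        xs, ids = sess.run([batch_x, batch_z], feed_dict={handle: predict_handle})

Note that only the short handle string goes through feed_dict here; the batches themselves never pass through a placeholder.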

Upvotes: 6
