Reputation: 824
I have an input pipeline where samples are generated on the fly. I use Keras with a custom ImageDataGenerator and a corresponding Iterator to get samples into memory. Under the assumption that Keras in my setup is using feed_dict (whether that assumption holds is itself one of my questions), I am thinking of speeding things up by switching to raw TensorFlow plus Dataset.from_generator().
Here I see that the suggested solution for input pipelines that generate data on the fly in the most recent TensorFlow is to use Dataset.from_generator().
Questions:
Upvotes: 2
Views: 932
Reputation: 126194
The new tf.contrib.data.Dataset.from_generator() can potentially speed up your input pipeline by overlapping the data preparation with training. However, you will tend to get the best performance by switching over to TensorFlow ops in your input pipeline wherever possible.
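For example, a pipeline that reads and resizes images entirely with TensorFlow ops might look like the following sketch. (This uses the TF 1.4+ tf.data names; in TF 1.3 the same methods live under tf.contrib.data. The file list and image size here are hypothetical.)

```python
import tensorflow as tf

filenames = ["img0.png", "img1.png"]  # hypothetical input files

def _parse(filename):
    # Decoding and resizing run as graph ops, so they can execute in
    # parallel with training instead of in the Python interpreter.
    image = tf.image.decode_png(tf.read_file(filename), channels=3)
    image = tf.image.resize_images(image, [224, 224])
    return image

dataset = (tf.data.Dataset.from_tensor_slices(filenames)
           .map(_parse, num_parallel_calls=4)
           .batch(32))
```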
To answer your specific questions:
The Keras TensorFlow backend uses tf.placeholder() to represent compiled function inputs, and feed_dict to pass arguments to a function.
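For illustration, a minimal sketch of that placeholder/feed_dict pattern (the shape and the reduce_mean standing in for a compiled Keras function are hypothetical):

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.reduce_mean(x)  # stand-in for a compiled Keras function

with tf.Session() as sess:
    batch = np.random.rand(32, 784).astype(np.float32)
    # Each call copies the NumPy batch into the TensorFlow runtime.
    print(sess.run(y, feed_dict={x: batch}))
```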
With the recent optimizations to tf.py_func() and feed_dict copy overhead, I suspect that the amount of time spent in memcpy() will be the same. However, you can more easily use Dataset.from_generator() with Dataset.prefetch() to overlap the training on one batch with the preprocessing of the next batch.
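A minimal sketch of that combination, assuming the TF 1.4+ tf.data API and a hypothetical Python generator standing in for your custom ImageDataGenerator:

```python
import numpy as np
import tensorflow as tf

def sample_generator():
    # Hypothetical stand-in for a custom ImageDataGenerator.
    while True:
        yield np.random.rand(64, 64, 3).astype(np.float32), np.int64(0)

dataset = (tf.data.Dataset.from_generator(
               sample_generator,
               output_types=(tf.float32, tf.int64),
               output_shapes=([64, 64, 3], []))
           .batch(32)
           # Prepare batch N+1 while the model trains on batch N.
           .prefetch(1))

images, labels = dataset.make_one_shot_iterator().get_next()
```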
It sounds like you can define a separate iterator for the prediction phase. The tf.estimator.Estimator class does something similar by instantiating different "input functions" with different signatures for training and evaluation, then building a separate graph for each role.
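A minimal sketch of that Estimator pattern, with a hypothetical model_fn and random data standing in for a real model:

```python
import numpy as np
import tensorflow as tf

def model_fn(features, labels, mode):
    # Hypothetical toy model: a single dense layer over the features.
    logits = tf.layers.dense(features, 2)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=logits)
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def train_input_fn():
    # Training input pipeline: features and labels.
    data = np.random.rand(100, 4).astype(np.float32)
    labels = np.random.randint(0, 2, 100).astype(np.int64)
    return (tf.data.Dataset.from_tensor_slices((data, labels))
            .batch(10).make_one_shot_iterator().get_next())

def predict_input_fn():
    # Prediction input pipeline: a different signature, features only.
    data = np.random.rand(10, 4).astype(np.float32)
    return (tf.data.Dataset.from_tensor_slices(data)
            .batch(10).make_one_shot_iterator().get_next())

estimator = tf.estimator.Estimator(model_fn)
estimator.train(train_input_fn, steps=10)
predictions = list(estimator.predict(predict_input_fn))
```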
Alternatively, you could add a dummy output to your training iterator (for the batch_z values) and switch between training and evaluation iterators using a "feedable iterator".
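A minimal sketch of the feedable-iterator approach, using tf.data.Iterator.from_string_handle() from TF 1.4+ (the data and shapes are hypothetical):

```python
import numpy as np
import tensorflow as tf

train_ds = (tf.data.Dataset
            .from_tensor_slices(np.random.rand(100, 4).astype(np.float32))
            .batch(10).repeat())
eval_ds = (tf.data.Dataset
           .from_tensor_slices(np.random.rand(20, 4).astype(np.float32))
           .batch(10))

# A string handle fed at sess.run() time selects which pipeline to pull from.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_batch = iterator.get_next()

train_iter = train_ds.make_one_shot_iterator()
eval_iter = eval_ds.make_initializable_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iter.string_handle())
    eval_handle = sess.run(eval_iter.string_handle())
    # Training step: feed the training handle.
    sess.run(next_batch, feed_dict={handle: train_handle})
    # Evaluation: switch pipelines without rebuilding the graph.
    sess.run(eval_iter.initializer)
    sess.run(next_batch, feed_dict={handle: eval_handle})
```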
Upvotes: 6