Reputation: 184
Basically, I have a list of images to be processed, and I need to do some pre-processing (data augmentation) after loading, before feeding them to the main TF graph. Currently I work with a customized generator that takes a list of paths, yields a pair of tensors (images), and feeds them to the network via a placeholder. This sequential processing takes ~0.5 s per batch.
I just read about the Dataset API, which I could use directly via the .from_generator() function, using .get_next() as the input directly.
But how does QueueRunner fit into this framework? Does Dataset implicitly use enqueue/dequeue to maintain its generator/get_next pipeline, or does it require me to explicitly feed into a FIFOQueue afterwards? If the latter, what is the best practice for maintaining the pipeline to train and validate over multiple shuffled (random_shuffle) epochs? (That is, how many datasets/queue runners do I need to maintain, and where do I set the shuffle and the number of epochs?)
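For reference, a minimal sketch of the .from_generator() approach I am considering (the generator and shapes here are toy stand-ins for my actual image loader; iteration style assumes eager execution, TF 2.x):

```python
import numpy as np
import tensorflow as tf

# Hypothetical generator standing in for the image-loading code:
# yields one (image, label) pair at a time.
def image_generator():
    for i in range(4):
        image = np.full((2, 2), float(i), dtype=np.float32)  # fake "image"
        yield image, i

dataset = tf.data.Dataset.from_generator(
    image_generator,
    output_types=(tf.float32, tf.int32),
    output_shapes=((2, 2), ()),
)
dataset = dataset.batch(2)  # group examples into batches of 2

# With eager execution the dataset is directly iterable;
# in graph mode you would use iterator.get_next() instead.
for images, labels in dataset:
    print(images.shape, labels.numpy())
```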
Upvotes: 1
Views: 1031
Reputation: 890
You don't have to use a QueueRunner to get queues/buffers if you are using the Dataset API. The Dataset API can create buffers itself and can pre-process data and train a network concurrently. Given a dataset, you can create a buffer by using either the prefetch function or the shuffle function (whose buffer_size argument controls the shuffle buffer).
See the official tutorial on the Dataset API for more information.
Here is an example of using a prefetch buffer with pre-processing on the CPU:
NUM_THREADS = 8
BUFFER_SIZE = 100

data = ...
labels = ...
inputs = (data, labels)

def pre_processing(data_, labels_):
    with tf.device("/cpu:0"):
        # do some pre-processing here
        return data_, labels_

dataset_source = tf.data.Dataset.from_tensor_slices(inputs)
dataset = dataset_source.map(pre_processing, num_parallel_calls=NUM_THREADS)
dataset = dataset.repeat(1)  # repeats for one epoch
dataset = dataset.prefetch(BUFFER_SIZE)

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
next_element = iterator.get_next()
init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(init_op)
    while True:
        try:
            sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
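As for the shuffle/epoch part of the question: both are set on the dataset pipeline itself, so one dataset is enough and no extra queues are needed. A minimal sketch with toy data (eager style, TF 2.x; the sizes are illustrative):

```python
import numpy as np
import tensorflow as tf

NUM_EPOCHS = 3
BATCH_SIZE = 4

data = np.arange(12, dtype=np.float32)   # toy stand-in for images
labels = np.arange(12, dtype=np.int32)

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.shuffle(buffer_size=12)  # reshuffles each epoch by default
dataset = dataset.repeat(NUM_EPOCHS)       # number of epochs
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(1)              # overlap input pipeline and training

# 12 examples * 3 epochs / batches of 4 = 9 batches total
num_batches = sum(1 for _ in dataset)
print(num_batches)
```

Placing shuffle before repeat gives epoch boundaries (each epoch is a full shuffled pass); repeat before shuffle would blur examples across epochs.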
Upvotes: 1