Qianyi Zhang

Reputation: 184

Tensorflow: How to use the "new" Dataset API with QueueRunner

Basically I have a list of images to be processed. I need to do some pre-processing (data augmentation) after loading and then feed them into the main TF graph. Currently I am working with a customized generator that takes a list of paths, yields a pair of tensors (images), and feeds them to the network via placeholders. This sequential processing takes ~0.5s per batch.

I just read about the Dataset API, which I could use directly via the .from_generator() function, and whose .get_next() output I could use directly as input.

But how does the QueueRunner fit into this framework? Does Dataset implicitly use enqueue/dequeue to maintain its generator/get_next pipeline, or does it require me to explicitly feed into a FIFOQueue afterwards? If the latter, what is the best practice for maintaining the pipeline to train and validate over multiple randomly shuffled epochs? (That is, how many Datasets/QueueRunners do I need to maintain, and where do I set the shuffling and the number of epochs?)
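For reference, the pipeline I have in mind looks roughly like the sketch below (image_generator, image_paths, load_and_augment and label_for stand in for my actual code, and the image shape is just an assumption):

 import tensorflow as tf

 def image_generator():
     # stand-in for my customized generator: loads a path, augments it,
     # and yields an (image, label) pair of numpy arrays
     for path in image_paths:
         yield load_and_augment(path), label_for(path)

 dataset = tf.data.Dataset.from_generator(
     image_generator,
     output_types=(tf.float32, tf.int32),
     output_shapes=(tf.TensorShape([224, 224, 3]), tf.TensorShape([])))
 dataset = dataset.batch(32)

 images, labels = dataset.make_one_shot_iterator().get_next()
 # images/labels would replace my placeholders as inputs to the main graph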

Upvotes: 1

Views: 1031

Answers (1)

CNugteren

Reputation: 890

You don't have to use a QueueRunner to get queues/buffers if you are using the Dataset API. The Dataset API creates queues/buffers itself and can pre-process data and train a network concurrently. Given a dataset, you create a queue/buffer with either the prefetch function or the shuffle function.

For more information, see the official tutorial on the Dataset API.

Here is an example of using a prefetch buffer with pre-processing on the CPU:

 import tensorflow as tf

 NUM_THREADS = 8
 BUFFER_SIZE = 100

 data = ...    # input data, e.g. a numpy array
 labels = ...  # matching labels
 inputs = (data, labels)

 def pre_processing(data_, labels_):
     # pin the pre-processing to the CPU so the GPU stays free for training
     with tf.device("/cpu:0"):
         # do some pre-processing here
         return data_, labels_

 dataset_source = tf.data.Dataset.from_tensor_slices(inputs)
 dataset = dataset_source.map(pre_processing, num_parallel_calls=NUM_THREADS)

 dataset = dataset.repeat(1)  # repeats for one epoch
 dataset = dataset.prefetch(BUFFER_SIZE)  # the buffer: pre-processes ahead of the consumer

 iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)
 next_element = iterator.get_next()
 init_op = iterator.make_initializer(dataset)

 with tf.Session() as sess:
     sess.run(init_op)
     while True:
         try:
             sess.run(next_element)
         except tf.errors.OutOfRangeError:  # raised once the epoch is exhausted
             break
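
To answer the shuffling/epochs part of the question: shuffling is just another Dataset transformation, so a single dataset per split and a single re-initializable iterator are enough, with no QueueRunner involved. A minimal sketch building on the code above (NUM_EPOCHS, SHUFFLE_BUFFER and valid_inputs are assumed values, not from the example):

 NUM_EPOCHS = 10
 SHUFFLE_BUFFER = 1000  # ideally >= number of examples for a full shuffle

 train_dataset = dataset_source.map(pre_processing, num_parallel_calls=NUM_THREADS)
 train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER)  # reshuffled on every pass by default
 train_dataset = train_dataset.repeat(NUM_EPOCHS)
 train_dataset = train_dataset.prefetch(BUFFER_SIZE)

 valid_dataset = tf.data.Dataset.from_tensor_slices(valid_inputs)
 valid_dataset = valid_dataset.prefetch(BUFFER_SIZE)

 # one re-initializable iterator serves both pipelines
 iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                            train_dataset.output_shapes)
 next_element = iterator.get_next()
 train_init_op = iterator.make_initializer(train_dataset)
 valid_init_op = iterator.make_initializer(valid_dataset)

Running train_init_op (re)starts the shuffled training pass; running valid_init_op points the same next_element at the validation data instead.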

Upvotes: 1
