Reputation: 1851
I was going through the CIFAR-10 example in the TensorFlow getting-started guide for CNNs. In the train function in cifar10_train.py we get the images as

```python
images, labels = cifar10.distorted_inputs()
```
In the distorted_inputs() function we generate the filenames in a queue and then read a single record:
```python
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)

# Read examples from files in the filename queue.
read_input = cifar10_input.read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
```
When I add debugging code, the read_input variable contains only a single record with an image and its height, width, and label. The example then applies some distortions to the read image/record and passes it to the _generate_image_and_label_batch() function. This function then returns a 4-D tensor of shape [batch_size, 32, 32, 3], where batch_size = 128.
That function uses tf.train.shuffle_batch() to produce the batch it returns.
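For reference, that call looks roughly like the following sketch (argument names are adapted from the tutorial's cifar10_input.py and may not match it exactly):

```python
# Inside _generate_image_and_label_batch(): `image` and `label` are
# tensors for a SINGLE (distorted) record produced by read_cifar10().
images, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=batch_size,                         # 128 in the tutorial
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples)
```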
My question is: where do the extra records come from in the tf.train.shuffle_batch() function? We are not passing it any filename or reader object. Can someone shed some light on how we go from 1 record to 128 records? I looked into the documentation but didn't understand it.
Upvotes: 8
Views: 1925
Reputation: 126154
The tf.train.shuffle_batch() function can be used to produce (one or more) tensors containing a batch of inputs. Internally, tf.train.shuffle_batch() creates a tf.RandomShuffleQueue, on which it calls q.enqueue() with the image and label tensors to enqueue a single element (an image-label pair). It then returns the result of q.dequeue_many(batch_size), which concatenates batch_size randomly selected elements (image-label pairs) into a batch of images and a batch of labels.
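A minimal sketch of what tf.train.shuffle_batch() sets up internally might look like this (simplified: the real implementation also handles multiple threads, shapes, and summaries; manual_shuffle_batch is a hypothetical name):

```python
import tensorflow as tf

def manual_shuffle_batch(image, label, batch_size=128,
                         capacity=1000, min_after_dequeue=500):
    # The queue holds single (image, label) elements, not whole batches.
    queue = tf.RandomShuffleQueue(
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        dtypes=[image.dtype, label.dtype],
        shapes=[image.get_shape(), label.get_shape()])

    # Each run of enqueue_op pulls ONE fresh record through the reader
    # and distortion ops upstream, then adds it to the queue. This is
    # where the "extra" records come from: the op is run repeatedly.
    enqueue_op = queue.enqueue([image, label])

    # Register a QueueRunner so that tf.train.start_queue_runners()
    # later spawns threads that keep running enqueue_op.
    tf.train.add_queue_runner(
        tf.train.QueueRunner(queue, [enqueue_op] * 4))

    # dequeue_many() concatenates batch_size randomly selected elements
    # into a [batch_size, ...] image tensor and a [batch_size] label tensor.
    return queue.dequeue_many(batch_size)
```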
Note that, although it looks from the code like read_input and filename_queue have a functional relationship, there is an additional wrinkle: simply evaluating the result of tf.train.shuffle_batch() will block forever, because no elements have been added to the internal queue. To simplify this, when you call tf.train.shuffle_batch(), TensorFlow will add a QueueRunner to an internal collection in the graph. A later call to tf.train.start_queue_runners() (e.g. here in cifar10_train.py) will start a thread that adds elements to the queue, and enables training to proceed. The Threading and Queues HOWTO has more information on how this works.
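A minimal sketch of that driver loop, assuming images and labels came from cifar10.distorted_inputs() (the loop body here is illustrative, not the tutorial's training code):

```python
import tensorflow as tf

with tf.Session() as sess:
    # Start the background threads registered by shuffle_batch() and
    # string_input_producer(); without this call, sess.run(images)
    # would block forever because nothing ever enqueues an element.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in range(10):
            # Each run dequeues a fresh [128, 32, 32, 3] batch of images
            # and a [128] batch of labels from the shuffling queue.
            image_batch, label_batch = sess.run([images, labels])
    finally:
        # Ask the enqueue threads to stop, then wait for them to exit.
        coord.request_stop()
        coord.join(threads)
```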
Upvotes: 7