t0mkaka

How does the distorted_inputs() function in the TensorFlow CIFAR-10 example tutorial get 128 images per batch?

I was going through the CIFAR-10 example in the TensorFlow getting-started guide for CNNs.

In the train() function in cifar10_train.py, we get the images as follows:

images, labels = cifar10.distorted_inputs()

In the distorted_inputs() function, we create a queue of filenames and then read a single record:

 # Create a queue that produces the filenames to read.
 filename_queue = tf.train.string_input_producer(filenames)

 # Read examples from files in the filename queue.
 read_input = cifar10_input.read_cifar10(filename_queue)
 reshaped_image = tf.cast(read_input.uint8image, tf.float32)

When I add debugging code, I see that the read_input variable contains only one record: an image together with its height, width, and label.
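To make the "one record per read" point concrete, here is a minimal pure-Python sketch (no TensorFlow) of what decoding a single CIFAR-10 binary record involves. It assumes the standard CIFAR-10 binary layout: 1 label byte followed by 3,072 image bytes stored as [depth, height, width], i.e. 3,073 bytes per record. The helpers decode_record() and read_records() are hypothetical illustrations, not functions from the tutorial; each step of the generator yields exactly one (label, image) pair, just as each evaluation of read_input produces one example.

```python
# Hypothetical sketch: decode ONE CIFAR-10 binary record in plain Python.
# Assumes the standard layout: 1 label byte + 3072 image bytes, stored
# channel-major (1024 red, 1024 green, 1024 blue, each row-major).
HEIGHT, WIDTH, DEPTH = 32, 32, 3
LABEL_BYTES = 1
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH
RECORD_BYTES = LABEL_BYTES + IMAGE_BYTES  # 3073 bytes per example


def decode_record(record: bytes):
    """Return (label, image), where image is a nested list of shape [32][32][3]."""
    if len(record) != RECORD_BYTES:
        raise ValueError(f"expected {RECORD_BYTES} bytes, got {len(record)}")
    label = record[0]
    raw = record[LABEL_BYTES:]
    # Transpose from [depth, height, width] storage to [height, width, depth].
    image = [
        [
            [raw[c * HEIGHT * WIDTH + h * WIDTH + w] for c in range(DEPTH)]
            for w in range(WIDTH)
        ]
        for h in range(HEIGHT)
    ]
    return label, image


def read_records(data: bytes):
    """Yield one (label, image) pair per fixed-length record, like a reader does."""
    for offset in range(0, len(data) - RECORD_BYTES + 1, RECORD_BYTES):
        yield decode_record(data[offset:offset + RECORD_BYTES])


# Usage: two synthetic records (label 7, pixel bytes 0..255 repeating).
record = bytes([7]) + bytes(i % 256 for i in range(IMAGE_BYTES))
examples = list(read_records(record * 2))
print(len(examples))   # 2 -> one (label, image) pair per record
print(examples[0][0])  # 7
```

Nothing in the reader itself knows about batches: it only ever hands out one example at a time.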

The example then applies some distortion to the read image/record and then passes it to the _generate_image_and_label_batch() function.

This function then returns a 4D Tensor of shape [batch_size, 32, 32, 3] where batch_size = 128.

That function uses tf.train.shuffle_batch() to build the batch it returns.

My question is: where do the extra records in tf.train.shuffle_batch() come from? We are not passing it any filename or reader object.

Can someone shed some light on how we go from 1 record to 128 records? I looked into the documentation but didn't understand it.

Answers (1)

mrry

The tf.train.shuffle_batch() function can be used to produce (one or more) tensors containing a batch of inputs. Internally, tf.train.shuffle_batch() creates a tf.RandomShuffleQueue, on which it calls q.enqueue() with the image and label tensors to enqueue a single element (image-label pair). It then returns the result of q.dequeue_many(batch_size), which concatenates batch_size randomly selected elements (image-label pairs) into a batch of images and a batch of labels.
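The enqueue-one/dequeue-many mechanism can be modelled without TensorFlow. The following is a simplified, hypothetical sketch in plain Python (the class ToyRandomShuffleQueue and its methods are illustrative names, not TensorFlow's implementation): enqueue() accepts a single image-label pair, and dequeue_many() blocks until at least batch_size + min_after_dequeue pairs are buffered, then removes batch_size of them at random and splits them into a batch of images and a batch of labels. The min_after_dequeue argument mirrors the parameter of the same name on tf.train.shuffle_batch().

```python
import random
import threading


class ToyRandomShuffleQueue:
    """Simplified, hypothetical stand-in for tf.RandomShuffleQueue.

    enqueue() adds ONE element; dequeue_many(n) blocks until it can remove n
    randomly chosen elements while still leaving min_after_dequeue behind.
    """

    def __init__(self, capacity, min_after_dequeue, seed=None):
        self._items = []
        self._capacity = capacity
        self._min_after = min_after_dequeue
        self._rng = random.Random(seed)
        self._cond = threading.Condition()

    def size(self):
        with self._cond:
            return len(self._items)

    def enqueue(self, element):
        with self._cond:
            # Block while the queue is full (a producer thread would wait here).
            self._cond.wait_for(lambda: len(self._items) < self._capacity)
            self._items.append(element)
            self._cond.notify_all()

    def dequeue_many(self, n, timeout=None):
        with self._cond:
            # Block until enough elements exist to keep the buffer well mixed.
            if not self._cond.wait_for(
                    lambda: len(self._items) >= n + self._min_after, timeout):
                raise TimeoutError("not enough elements have been enqueued")
            batch = []
            for _ in range(n):
                i = self._rng.randrange(len(self._items))
                # Swap the chosen element to the end, then pop it in O(1).
                self._items[i], self._items[-1] = self._items[-1], self._items[i]
                batch.append(self._items.pop())
            self._cond.notify_all()
        # Split the (image, label) pairs into a batch of images and of labels.
        images, labels = zip(*batch)
        return list(images), list(labels)


# Usage: enqueue single (image, label) pairs one at a time, then take 128.
q = ToyRandomShuffleQueue(capacity=200, min_after_dequeue=10, seed=0)
for k in range(138):
    q.enqueue((f"image-{k}", k % 10))
images, labels = q.dequeue_many(128)
print(len(images), len(labels), q.size())  # 128 128 10
```

In this toy model, capacity must be at least batch_size + min_after_dequeue; otherwise dequeue_many() can never be satisfied. The 128 examples in the batch are simply 128 separate single-example enqueues.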

Note that, although the code makes it look as if read_input and filename_queue have a functional relationship, there is an additional wrinkle. Simply evaluating the result of tf.train.shuffle_batch() will block forever, because no elements have been added to the internal queue. To deal with this, when you call tf.train.shuffle_batch(), TensorFlow adds a QueueRunner to an internal collection in the graph. A later call to tf.train.start_queue_runners() (e.g. in cifar10_train.py) starts the threads that add elements to the queue, which lets training proceed. The Threading and Queues HOWTO has more information on how this works.
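The blocking behaviour and the role of the queue runner can be reproduced in a few lines of plain Python. This is a hypothetical sketch, not TensorFlow code: a stdlib queue.Queue (FIFO, so no shuffling) stands in for the internal queue, a daemon thread started by the illustrative helper start_toy_queue_runner() stands in for the QueueRunner, and a threading.Event plays roughly the role of a stop signal. Without the thread, dequeue_many() times out (TensorFlow would instead block forever); once the thread is started, a batch of 128 arrives.

```python
import itertools
import queue
import threading

BATCH_SIZE = 128


def dequeue_many(q, n, timeout):
    """Collect n elements, raising queue.Empty if any single get() times out."""
    return [q.get(timeout=timeout) for _ in range(n)]


def make_reader():
    """Hypothetical reader: each call returns ONE (image_id, label) pair."""
    counter = itertools.count()

    def read_one_example():
        k = next(counter)
        return (f"image-{k}", k % 10)

    return read_one_example


def start_toy_queue_runner(q, read_one_example, stop_event):
    """Hypothetical QueueRunner analogue: a thread that keeps reading one
    example at a time and enqueueing it until asked to stop."""
    def loop():
        while not stop_event.is_set():
            try:
                q.put(read_one_example(), timeout=0.1)
            except queue.Full:
                continue  # Queue is full; re-check the stop flag and retry.

    t = threading.Thread(target=loop, daemon=True)
    t.start()
    return t


examples = queue.Queue(maxsize=1000)

# 1) No runner started: nothing is ever enqueued, so dequeue_many() cannot
#    finish; a short timeout makes the blocking visible.
try:
    dequeue_many(examples, BATCH_SIZE, timeout=0.2)
    blocked = False
except queue.Empty:
    blocked = True
print("blocked without runner:", blocked)  # True

# 2) Start the producer thread (the analogue of start_queue_runners()).
stop = threading.Event()
runner = start_toy_queue_runner(examples, make_reader(), stop)
batch = dequeue_many(examples, BATCH_SIZE, timeout=5.0)
stop.set()
runner.join()
print("batch size:", len(batch))  # 128
```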
