Clock Slave

Reputation: 7967

Tensorflow - Next batch of data from tf.train.shuffle_batch

I have a tfrecords file from which I am looking to create batches of data. I am using tf.train.shuffle_batch() to create a single batch. In my training loop I want to fetch the batches and pass them in, and this is where I am stuck. I read that the position of the TFRecordReader() gets saved in the state of the graph, so the next example is read from the subsequent position. The trouble is that I can't figure out how to load the next batch. I'm using the code below to create the batches.

def read_and_decode_single_example(filename):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'context': tf.FixedLenFeature([160], tf.int64),
            'context_len': tf.FixedLenFeature([1], tf.int64),
            'utterance': tf.FixedLenFeature([160], tf.int64),
            'utterance_len': tf.FixedLenFeature([1], tf.int64),
            'label': tf.FixedLenFeature([1], tf.int64)
        })

    contexts = features['context']
    context_lens = features['context_len']
    utterances = features['utterance']
    utterance_lens = features['utterance_len']
    labels = features['label']

    return contexts, context_lens, utterances, utterance_lens, labels

contexts, context_lens, utterances, utterance_lens, labels = \
    read_and_decode_single_example('data/train.tfrecords')

contexts_batch, context_lens_batch, \
    utterances_batch, utterance_lens_batch, \
    labels_batch = tf.train.shuffle_batch([contexts, context_lens, utterances,
                                          utterance_lens, labels],
                                          batch_size=batch_size,
                                          capacity=3*batch_size,
                                          min_after_dequeue=batch_size)

This gives me a single batch of data. I want to use the feed_dict paradigm to pass the batches during training, so that on each iteration a new batch is passed in. How do I load these batches? Will calling read_and_decode along with tf.train.shuffle_batch again load the next batch?

Upvotes: 1

Views: 1192

Answers (1)

sunside

Reputation: 8249

The read_and_decode_single_example() function creates a (sub-)graph for the network that is used to load data; you only call that once. It might be more appropriately called build_read_and_decode_single_example_graph(), but that's a bit long.

The "magic" lies in evaluating (i.e. using) the _batch tensors multiple times, e.g.

batch_size = 100
# ...

with tf.Session() as sess:
    # num_epochs in string_input_producer creates a local variable
    sess.run(tf.local_variables_initializer())

    # start the queue-runner threads that fill the batching queues;
    # without this, the sess.run() calls below would block forever
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # get the first batch of 100 values
    first_batch = sess.run([contexts_batch, context_lens_batch,
                            utterances_batch, utterance_lens_batch,
                            labels_batch])

    # second batch of 100 different values
    second_batch = sess.run([contexts_batch, context_lens_batch,
                             utterances_batch, utterance_lens_batch,
                             labels_batch])
    # etc.

    coord.request_stop()
    coord.join(threads)

Of course, rather than fetching these values from the session manually, you would feed them into some other part of the network instead. The mechanism is the same: whenever one of these tensors is fetched, directly or indirectly, the batching mechanism takes care of providing a new batch (of different values) each time.
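To tie this back to the feed_dict paradigm the question asks about, the pattern is: fetch a batch with sess.run(), then feed the resulting NumPy arrays into placeholders. Here is a minimal, self-contained sketch; the TFRecord pipeline is replaced by a synthetic queue over integers so it runs without a .tfrecords file, and value_ph / train_op are hypothetical stand-ins for a real model.

```python
import tensorflow as tf

# This sketch uses the TF 1.x queue API from the question; the shim
# below only makes it also run under TF 2.x and can be dropped on 1.x.
if hasattr(tf, "compat") and hasattr(tf.compat, "v1"):
    tf = tf.compat.v1
if hasattr(tf, "disable_eager_execution"):
    tf.disable_eager_execution()

batch_size = 4

# Synthetic stand-in for the TFRecord pipeline: a queue over the
# integers 0..19 feeding a shuffle_batch queue, just like the
# reader/shuffle_batch combination in the question.
value = tf.train.input_producer(tf.range(20)).dequeue()
value_batch = tf.train.shuffle_batch([value],
                                     batch_size=batch_size,
                                     capacity=3 * batch_size,
                                     min_after_dequeue=batch_size)

# Hypothetical placeholder standing in for a model input; a real model
# would build its loss and training op from placeholders like this one.
value_ph = tf.placeholder(tf.int32, shape=[batch_size])
train_op = 2 * value_ph  # stand-in for an actual training step

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    results = []
    for step in range(3):
        # each sess.run() on value_batch dequeues a fresh batch ...
        batch_np = sess.run(value_batch)
        # ... which is then fed into the "model" via feed_dict
        results.append(sess.run(train_op, feed_dict={value_ph: batch_np}))

    coord.request_stop()
    coord.join(threads)
```

Note that feeding the batch tensors straight into the model graph, as described above, skips the extra round trip through feed_dict; the feed_dict variant is mainly useful if you also want to feed data from other sources through the same placeholders.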

Upvotes: 1
