Reputation: 7967
I have a tfrecords file from which I am looking to create batches of data. I am using tf.train.shuffle_batch() to create a single batch. In my training I want to call the batches and pass them in, and this is where I am stuck. I read that the position of the TFRecordReader() gets saved in the state of the graph, and the next example is read from the subsequent position. The trouble is, I'm not able to figure out how to load the next batch. I'm using the code below to create the batches.
def read_and_decode_single_example(filename):
    filename_queue = tf.train.string_input_producer([filename], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'context': tf.FixedLenFeature([160], tf.int64),
            'context_len': tf.FixedLenFeature([1], tf.int64),
            'utterance': tf.FixedLenFeature([160], tf.int64),
            'utterance_len': tf.FixedLenFeature([1], tf.int64),
            'label': tf.FixedLenFeature([1], tf.int64)
        })
    contexts = features['context']
    context_lens = features['context_len']
    utterances = features['utterance']
    utterance_lens = features['utterance_len']
    labels = features['label']
    return contexts, context_lens, utterances, utterance_lens, labels
contexts, context_lens, utterances, utterance_lens, labels = \
    read_and_decode_single_example('data/train.tfrecords')

contexts_batch, context_lens_batch, \
utterances_batch, utterance_lens_batch, \
labels_batch = tf.train.shuffle_batch([contexts, context_lens, utterances,
                                       utterance_lens, labels],
                                      batch_size=batch_size,
                                      capacity=3*batch_size,
                                      min_after_dequeue=batch_size)
This gives me a single batch of data. I want to use the feed_dict paradigm to pass the batches during training, where on each iteration a new batch gets passed in. How do I load these batches? Will calling read_and_decode along with tf.train.shuffle_batch again load the next batch?
Upvotes: 1
Views: 1192
Reputation: 8249
The read_and_decode_single_example() function creates the (sub-)graph of the network that is used to load data; you only call it once. It might be more appropriately called build_read_and_decode_single_example_graph(), but that's a bit long.
The "magic" lies in evaluating (i.e. using) the _batch
tensors multiple times, e.g.
batch_size = 100
# ...
with tf.Session() as sess:
    # num_epochs in string_input_producer creates a local variable,
    # so local variables must be initialized
    sess.run(tf.local_variables_initializer())
    # start the queue-runner threads that fill the input pipeline;
    # without this, sess.run() on the batch tensors blocks forever
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # get the first batch of 100 values
    first_batch = sess.run([contexts_batch, context_lens_batch,
                            utterances_batch, utterance_lens_batch,
                            labels_batch])
    # second batch of 100 different values
    second_batch = sess.run([contexts_batch, context_lens_batch,
                             utterances_batch, utterance_lens_batch,
                             labels_batch])
    # etc.
    coord.request_stop()
    coord.join(threads)
Of course, rather than fetching these values from the session manually, you would typically feed them into some other part of the network instead. The mechanism is the same: whenever one of these tensors is fetched, directly or indirectly, the batching mechanism takes care of providing you with a new batch (of different values) each time.
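If it helps to see the "new values on every evaluation" behaviour outside of TensorFlow, here is a plain-Python analogy (the generator below is only an illustration, not part of the TF API): every next() call plays the role of a sess.run() on the batch tensors.

```python
import random

def shuffle_batches(examples, batch_size, seed=0):
    """Plain-Python stand-in for the queue-based pipeline: each call to
    next() on the generator yields a fresh shuffled batch, just as each
    sess.run() on the *_batch tensors dequeues a new batch."""
    rng = random.Random(seed)
    pool = list(examples)
    while len(pool) >= batch_size:
        rng.shuffle(pool)
        # take one batch and leave the rest in the pool
        batch, pool = pool[:batch_size], pool[batch_size:]
        yield batch

batches = shuffle_batches(range(10), batch_size=3)
first = next(batches)   # a new batch on each call...
second = next(batches)  # ...containing different values
```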
Upvotes: 1