golmschenk

Reputation: 12384

Reading sequential data from TFRecords files within the TensorFlow graph?

I'm working with video data, but I believe this question applies to any sequential data. I want to pass my RNN 10 sequential examples (video frames) from a TFRecords file. When I first start reading the file, I need to grab 10 examples and use them to create a sequence example, which is then pushed onto the queue for the RNN to take when it's ready. Once I have those 10 frames, the next read from the TFRecords file only needs to take 1 example and shift the other 9 over. When I hit the end of the first TFRecords file, I need to restart the process on the second TFRecords file. It's my understanding that the cond op will process the ops required under each condition, even for the condition that is not selected. This would be a problem when using a condition to check whether to read 10 examples or only 1. Is there any way to work around this and still get the result outlined above?

Upvotes: 1

Views: 1125

Answers (1)

mrry

Reputation: 126154

You can use the recently added Dataset.window() transformation in TensorFlow 1.12 to do this:

import tensorflow as tf

filenames = tf.data.Dataset.list_files(...)

# Define a function that will be applied to each filename, and return the sequences in that
# file.
def get_examples_from_file(filename):
  # Read and parse the examples from the file using the appropriate logic.
  examples = tf.data.TFRecordDataset(filename).map(...)

  # Selects a sliding window of 10 examples, shifting along 1 example at a time.
  sequences = examples.window(size=10, shift=1, drop_remainder=True)

  # Each element of `sequences` is a nested dataset containing 10 consecutive examples.
  # Use `Dataset.batch()` and get the resulting tensor to convert it to a tensor value
  # (or values, if there are multiple features in an example).
  return sequences.map(
      lambda d: tf.data.experimental.get_single_element(d.batch(10)))

# Alternatively, you can use `filenames.interleave()` to mix together sequences from
# different files.
sequences = filenames.flat_map(get_examples_from_file)
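
To make the windowing behaviour concrete, here is a minimal pure-Python sketch of what the pipeline above is expected to produce. It is not TensorFlow code: the helper names `sliding_windows` and `sequences_from_files` and the toy "files" are hypothetical stand-ins, and the window size is 3 rather than 10 to keep the output short. It illustrates two properties: each new window keeps the previous examples and adds one new one, and windows restart at every file boundary.

```python
from collections import deque
from typing import Iterable, Iterator, List, Tuple


def sliding_windows(examples: Iterable[int], size: int) -> Iterator[Tuple[int, ...]]:
    """Yield every full window of `size` consecutive items, shifting by 1.

    Mirrors window(size=size, shift=1, drop_remainder=True): windows that
    would contain fewer than `size` items are never emitted.
    """
    # A bounded deque drops its oldest item when a new one is appended,
    # which is the "take 1 example and shift the others over" step.
    buffer = deque(maxlen=size)
    for example in examples:
        buffer.append(example)
        if len(buffer) == size:
            yield tuple(buffer)


def sequences_from_files(files: Iterable[List[int]], size: int) -> Iterator[Tuple[int, ...]]:
    """Window each file independently, like flat_map over filenames."""
    for examples in files:
        # A fresh buffer per file, so no window spans two files.
        yield from sliding_windows(examples, size)


# Hypothetical stand-ins for two TFRecords files holding 5 and 4 frames.
files = [[0, 1, 2, 3, 4], [10, 11, 12, 13]]
result = list(sequences_from_files(files, size=3))
print(result)
```

Because the windowing happens inside the per-file function, the first window of the second file starts from that file's first example, which is the "restart the process" behaviour the question asks for, without needing a conditional op.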

Upvotes: 2
