Gabriel Perdue

Reputation: 1593

Read TFRecord image data with new TensorFlow Dataset API

I am having trouble reading TFRecord-format image data using the "new" (TensorFlow v1.4) Dataset API. I believe the problem is that I am somehow consuming the whole dataset instead of a single batch when trying to read. I have a working example that uses the older batch/file-queue API here: https://github.com/gnperdue/TFExperiments/tree/master/conv (in that example I am running a classifier, but the code that reads the TFRecord images is in DataReaders.py).

The problem functions are, I believe, these:

def parse_mnist_tfrec(tfrecord, features_shape):
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string)
        }
    )
    features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
    features = tf.reshape(features, features_shape)
    features = tf.cast(features, tf.float32)
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
    targets = tf.cast(targets, tf.float32)
    return features, targets

class MNISTDataReaderDset:
    def __init__(self, data_reader_dict):
        # contents elided - doesn't matter here
        pass

    def batch_generator(self, num_epochs=1):
        def parse_fn(tfrecord):
            return parse_mnist_tfrec(
                tfrecord, self.features_shape
            )
        dataset = tf.data.TFRecordDataset(
            self.filenames_list, compression_type=self.compression_type
        )
        dataset = dataset.map(parse_fn)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, batch_labels = iterator.get_next()
        return batch_features, batch_labels

Then, in use:

        batch_features, batch_labels = \
            data_reader.batch_generator(num_epochs=1)

        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # look at 3 batches only
            for _ in range(3):
                labels, feats = sess.run([
                    batch_labels, batch_features
                ])

This generates an error like:

 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 Input to reshape is a tensor with 50000 values, but the requested shape has 1
 [[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Does anyone have any ideas?

I have a gist with the full code in the reader example and a link to the TFRecord files (our old, good friend MNIST, in TFRecord form) here:

https://gist.github.com/gnperdue/56092626d611ae23370a21fdeeb2abe8

Thanks!

Edit - I also tried a flat_map, e.g.:

def batch_generator(self, num_epochs=1):
    """
    TODO - we can use placeholders for the list of file names and
    init with a feed_dict when we call `sess.run` - give this a
    try with one list for training and one for validation
    """
    def parse_fn(tfrecord):
        return parse_mnist_tfrec(
            tfrecord, self.features_shape
        )
    dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
    dataset = dataset.flat_map(
        lambda filename: (
            tf.data.TFRecordDataset(
                filename, compression_type=self.compression_type
            ).map(parse_fn).batch(self.batch_size)
        )
    )
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

I also tried using just one file and not a list (in my first way of approaching this above). No matter what, it seems TF always wants to eat the entire file into the TFRecordDataset and won't operate on single records.

Upvotes: 4

Views: 1345

Answers (1)

Gabriel Perdue

Reputation: 1593

Okay, I figured this out - the code above is fine. The problem was my script for creating the TFRecords. Basically, I had a block like this:

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    tfeat, ttarg = get_binary_data(reader, start_idx, stop_idx)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'features': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[tfeat])
                ),
                'targets': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[ttarg])
                )
            }
        )
    )
    writer.write(example.SerializeToString())
    writer.close()

and I needed a block like this instead:

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    for idx in range(start_idx, stop_idx):
        tfeat, ttarg = get_binary_data(reader, idx)
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'features': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tfeat])
                    ),
                    'targets': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[ttarg])
                    )
                }
            )
        )
        writer.write(example.SerializeToString())
    writer.close()

Which is to say: I was writing my entire block of data as one giant TFRecord when I needed to be making one record per example in the data.

It turns out that either way works with the old file and batch-queue API - functions like tf.train.batch are auto-magically 'smart' enough to either carve the big block up or concatenate lots of single-example records into a batch, depending on what you give them. When I fixed the code that writes the TFRecords file, I didn't need to change anything in my old file and batch-queue code and it still consumed the file just fine. The Dataset API, however, is sensitive to this difference. That is why in my code above it always appeared to be consuming the entire file: the entire file really was one big TFRecord (which is also why the reshape error above complains about a tensor with 50000 values - a whole split's worth of data came back as a single record).
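
If anyone wants a quick sanity check for this kind of problem, counting the serialized records in a file tells you immediately whether you wrote one giant record or one record per example. Here is a minimal sketch using the TF 1.x tf.python_io.tf_record_iterator - the file name and the GZIP option are placeholders for whatever you actually used when writing:

import tensorflow as tf

# placeholder path - point this at the TFRecord file you wrote
tfrecord_file = 'mnist_train.tfrecord.gz'

# match whatever compression you used when writing (or pass options=None)
options = tf.python_io.TFRecordOptions(
    tf.python_io.TFRecordCompressionType.GZIP
)

# each item yielded by the iterator is one serialized tf.train.Example
num_records = sum(
    1 for _ in tf.python_io.tf_record_iterator(tfrecord_file, options=options)
)
print('records in file:', num_records)

With my broken writer this prints 1; after the fix it prints the number of examples written to that file.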

Upvotes: 1
