GunnarTrollson

Reputation: 61

How to read (decode) tfrecords with tf.data API

I have a custom dataset that I then stored as a tfrecord, doing

# toy example data
label = np.asarray([[1,2,3],
                    [4,5,6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

def bytes_feature(values):
    """Returns a TF-Feature of bytes.
    Args:
    values: A string.
    Returns:
    A TF-Feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def labeled_image_to_tfexample(sample_binary_string, label_binary_string):
    return tf.train.Example(features=tf.train.Features(feature={
      'sample/image': bytes_feature(sample_binary_string),
      'sample/label': bytes_feature(label_binary_string)
    }))


def _write_to_tf_record():
    with tf.Graph().as_default():
        image_placeholder = tf.placeholder(dtype=tf.uint16)
        encoded_image = tf.image.encode_png(image_placeholder)

        label_placeholder = tf.placeholder(dtype=tf.uint16)
        encoded_label = tf.image.encode_png(label_placeholder)

        with tf.python_io.TFRecordWriter("./toy.tfrecord") as writer:
            with tf.Session() as sess:
                feed_dict = {image_placeholder: sample,
                             label_placeholder: label}

                # Encode image and label as binary strings to be written to tf_record
                image_string, label_string = sess.run(fetches=(encoded_image, encoded_label),
                                                      feed_dict=feed_dict)

                # Define structure of what is going to be written
                file_structure = labeled_image_to_tfexample(image_string, label_string)

                writer.write(file_structure.SerializeToString())
                return

However, I cannot read it back. First I tried the following (based on http://www.machinelearninguru.com/deep_learning/tensorflow/basics/tfrecord/tfrecord.html , https://medium.com/coinmonks/storage-efficient-tfrecord-for-images-6dc322b81db4 and https://medium.com/mostly-ai/tensorflow-records-what-they-are-and-how-to-use-them-c46bc4bbb564):

def read_tfrecord_low_level():
    data_path = "./toy.tfrecord"
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, raw_records = reader.read(filename_queue)

    decode_protocol = {
        'sample/image': tf.FixedLenFeature((), tf.int64),
        'sample/label': tf.FixedLenFeature((), tf.int64)
    }
    enc_example = tf.parse_single_example(raw_records, features=decode_protocol)
    recovered_image = enc_example["sample/image"]
    recovered_label = enc_example["sample/label"]

    return recovered_image, recovered_label

I also tried variations of casting enc_example and decoding it, such as in "Unable to read from Tensorflow tfrecord file". However, when I try to evaluate them, my Python session just freezes and gives no output or traceback.

Then I tried using eager execution to see what is happening, but apparently it is only compatible with the tf.data API. However, as far as I understand, transformations with the tf.data API are applied to the whole dataset. https://www.tensorflow.org/api_guides/python/reading_data mentions that a decode function must be written, but doesn't give an example of how to do that. All the tutorials I have found are written for TFRecordReader (which doesn't work for me).

Any help (pinpointing what I am doing wrong / explaining what is happening / indications on how to decode tfrecords with the tf.data API) is highly appreciated.

According to https://www.youtube.com/watch?v=4oNdaQk0Qv4 and https://www.youtube.com/watch?v=uIcqeP7MFH0, tf.data is the best way to create input pipelines, so I am highly interested in learning that way.

Thanks in advance!

Upvotes: 3

Views: 4846

Answers (1)

Walfits

Reputation: 446

I am not sure why storing the encoded PNG causes the evaluation to not work, but here is a possible way of working around the problem. Since you mentioned that you would like to use the tf.data way of creating input pipelines, I'll show how to use it with your toy example:

label = np.asarray([[1,2,3],
                    [4,5,6]]).reshape(2, 3, -1)

sample = np.stack((label + 200).reshape(2, 3, -1))

First, the data has to be saved to the TFRecord file. The difference from what you did is that the image is not encoded to png.

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.python_io.TFRecordWriter("toy.tfrecord")

example = tf.train.Example(features=tf.train.Features(feature={
            'label_raw': _bytes_feature(tf.compat.as_bytes(label.tostring())),
            'sample_raw': _bytes_feature(tf.compat.as_bytes(sample.tostring()))}))

writer.write(example.SerializeToString())

writer.close()

What happens in the code above is that the arrays are serialized into raw byte strings (flat, 1-d objects) and then stored as bytes features.
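
For instance, here is the same round trip in plain NumPy (a minimal sketch, assuming the toy arrays keep NumPy's default int64 dtype), which also shows why the shape has to be restored by hand when reading:

import numpy as np

label = np.asarray([[1, 2, 3],
                    [4, 5, 6]]).reshape(2, 3, -1)

# tostring() flattens the array into its raw bytes; the shape is not stored
raw = label.tostring()  # 6 int64 values -> 48 bytes

# Decoding gives back a flat buffer, so the original shape must be reapplied
recovered = np.frombuffer(raw, dtype=np.int64).reshape(2, 3, -1)
assert np.array_equal(recovered, label)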

Then, to read the data back using the tf.data.TFRecordDataset and tf.data.Iterator classes:

filename = 'toy.tfrecord'

# Create a placeholder that will contain the name of the TFRecord file to use
data_path = tf.placeholder(dtype=tf.string, name="tfrecord_file")

# Create the dataset from the TFRecord file
dataset = tf.data.TFRecordDataset(data_path)

# Use the map function to read every sample from the TFRecord file (_read_from_tfrecord is shown below)
dataset = dataset.map(_read_from_tfrecord)

# Create an iterator object that enables you to access all the samples in the dataset
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
label_tf, sample_tf = iterator.get_next()

# Similarly to tf.Variables, the iterators have to be initialised
iterator_init = iterator.make_initializer(dataset, name="dataset_init")

with tf.Session() as sess:
    # Initialise the iterator passing the name of the TFRecord file to the placeholder
    sess.run(iterator_init, feed_dict={data_path: filename})

    # Obtain the images and labels back
    read_label, read_sample = sess.run([label_tf, sample_tf])

The function _read_from_tfrecord() is:

def _read_from_tfrecord(example_proto):
    feature = {
        'label_raw': tf.FixedLenFeature([], tf.string),
        'sample_raw': tf.FixedLenFeature([], tf.string)
    }

    features = tf.parse_example([example_proto], features=feature)

    # Since the arrays were stored as strings, they are now 1d 
    label_1d = tf.decode_raw(features['label_raw'], tf.int64)
    sample_1d = tf.decode_raw(features['sample_raw'], tf.int64)

    # In order to make the arrays in their original shape, they have to be reshaped.
    label_restored = tf.reshape(label_1d, tf.stack([2, 3, -1]))
    sample_restored = tf.reshape(sample_1d, tf.stack([2, 3, -1]))

    return label_restored, sample_restored
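
As a quick sanity check (just a sketch, assuming the writer snippet above has already produced toy.tfrecord and that the toy arrays keep NumPy's default int64 dtype), the recovered arrays should match the originals exactly:

# After read_label, read_sample = sess.run([label_tf, sample_tf]) above
assert np.array_equal(read_label, label)
assert np.array_equal(read_sample, sample)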

Instead of hard-coding the shape [2, 3, -1], you could also store the shape in the TFRecord file, but for simplicity I didn't do it here.
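
If you wanted to avoid the hard-coded shape, a possible sketch (the label_shape feature name is made up here, and tf.parse_single_example is used instead of tf.parse_example just to keep the example short) is to write the shape as an int64 feature and feed it to tf.reshape when reading:

# Writing: store the shape next to the raw bytes (hypothetical 'label_shape' feature)
example = tf.train.Example(features=tf.train.Features(feature={
    'label_raw': _bytes_feature(tf.compat.as_bytes(label.tostring())),
    'label_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=list(label.shape)))}))

# Reading: parse the stored shape and use it for the reshape
feature = {
    'label_raw': tf.FixedLenFeature([], tf.string),
    'label_shape': tf.FixedLenFeature([3], tf.int64)
}
parsed = tf.parse_single_example(example_proto, features=feature)
label_restored = tf.reshape(tf.decode_raw(parsed['label_raw'], tf.int64),
                            parsed['label_shape'])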

I made a little gist with a working example.

Hope this helps!

Upvotes: 4
