Gabriel Perdue

Reputation: 1593

`tf.train.shuffle_batch` crashes when reading `TFRecord` files in TensorFlow

I am trying to use `tf.train.shuffle_batch` to consume batches of data from a `TFRecord` file using TensorFlow 1.0. The relevant functions are:

import tensorflow as tf

def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)

    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    ## if no reshaping: `ValueError: All shapes must be fully defined` in
    ## `tf.train.shuffle_batch`
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    ## if using `strided_slice`, always get the first record
    # targets = tf.cast(
    #     tf.strided_slice(targets, [0], [1]),
    #     tf.int32
    # )
    ## error on shapes being fully defined
    # targets = tf.reshape(targets, [])
    ## get us: Invalid argument: Shape mismatch in tuple component 0.
    ## Expected [1], got [1000]
    targets.set_shape([1])
    return targets


def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size)
    )
    targets_batch = tf.one_hot(
        indices=targets_batch, depth=10, on_value=1, off_value=0
    )
    return targets_batch


def examine_batches(targets_batch):
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(10):
            targets = sess.run([targets_batch])
            print(targets)
        coord.request_stop()
        coord.join(threads)

Execution enters through examine_batches(), which is handed the output of batch_generator(). batch_generator() calls tfrecord_to_graph_ops(), and I believe the problem is in that function.

I am calling

targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)

on a file with 1,000 bytes (the numbers 0-9). If I call `eval()` on this tensor in a `Session`, it shows me all 1,000 elements. But if I try to put it in a batch generator, it crashes.

If I don't reshape targets, I get an error like `ValueError: All shapes must be fully defined` when `tf.train.shuffle_batch` is called. If I call `targets.set_shape([1])`, reminiscent of Google's CIFAR-10 example code, I get an error like `Invalid argument: Shape mismatch in tuple component 0. Expected [1], got [1000]` in `tf.train.shuffle_batch`. I also tried using `tf.strided_slice` to cut a chunk of the raw data; this doesn't crash, but it results in just getting the first record over and over again.

What is the right way to pull batches of data from a `TFRecord` file?

Note that I could manually write a function to chop up the raw byte data and do some sort of batching (especially easy if I used the `feed_dict` approach to getting data into the graph), but I am trying to learn how to use TensorFlow's `TFRecord` files and their built-in batching functions.
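For reference, here is a minimal sketch of how a file like the one described above could be written (the filename `targets.tfrecord` is just an illustrative choice, and I'm assuming the 1,000 bytes are the digits 0-9 repeated, stored as a single raw-bytes feature):

```python
import numpy as np
import tensorflow as tf

def write_sample_tfrecord(filename='targets.tfrecord'):
    # 1,000 uint8 values cycling through 0-9, serialized as one
    # raw-bytes 'targets' feature in a single Example record.
    data = (np.arange(1000) % 10).astype(np.uint8)
    example = tf.train.Example(features=tf.train.Features(feature={
        'targets': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[data.tobytes()])
        )
    }))
    with tf.python_io.TFRecordWriter(filename) as writer:
        writer.write(example.SerializeToString())
```

Because the whole array is serialized as one feature in one record, `tf.decode_raw` on the parsed feature yields a single tensor of 1,000 elements, which is why `eval()` shows all 1,000 at once.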

Thanks!

Upvotes: 1

Views: 449

Answers (1)

Gabriel Perdue

Reputation: 1593

Allen Lavoie pointed out the correct solution in a comment: the missing piece was `enqueue_many=True` as an argument to `tf.train.shuffle_batch()`. With `enqueue_many=True`, the first dimension of the input tensor is treated as the example axis, so each of the 1,000 decoded bytes is enqueued as a separate scalar example rather than the whole `[1000]` tensor being treated as one example. The correct way to write those functions is:

import tensorflow as tf

def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)

    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.reshape(targets, [-1])
    return targets

def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size),
        enqueue_many=True
    )
    return targets_batch
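The shape semantics of `enqueue_many` can be illustrated with plain NumPy (an analogy only, not TensorFlow itself):

```python
import numpy as np

# Stand-in for the decoded `targets` tensor: 1,000 uint8 values 0-9.
targets = (np.arange(1000) % 10).astype(np.uint8)

batch_size = 10

# enqueue_many=False would treat `targets` as ONE example of shape (1000,),
# so a batch of 10 would have shape (10, 1000) -- and would require the
# example's shape to be fully defined up front.
batch_if_single_example_shape = (batch_size,) + targets.shape

# enqueue_many=True treats axis 0 as the example axis: the queue holds
# 1,000 scalar examples, and a batch of 10 has shape (10,).
batch = targets[:batch_size]  # one (unshuffled) batch of scalar examples

print(batch_if_single_example_shape)  # (10, 1000)
print(batch.shape)                    # (10,)
```

This is also why the `reshape(targets, [-1])` matters: it gives the decoded tensor a rank-1 shape whose first dimension `tf.train.shuffle_batch` can split into individual examples.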

Upvotes: 1
