Reputation: 1593
I am trying to use tf.train.shuffle_batch to consume batches of data from a TFRecord file using TensorFlow 1.0. The relevant functions are:
def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    ## if no reshaping: `ValueError: All shapes must be fully defined` in
    ## `tf.train.shuffle_batch`
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    ## if using `strided_slice`, always get the first record
    # targets = tf.cast(
    #     tf.strided_slice(targets, [0], [1]),
    #     tf.int32
    # )
    ## error on shapes being fully defined
    # targets = tf.reshape(targets, [])
    ## get us: Invalid argument: Shape mismatch in tuple component 0.
    ## Expected [1], got [1000]
    targets.set_shape([1])
    return targets
def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size)
    )
    targets_batch = tf.one_hot(
        indices=targets_batch, depth=10, on_value=1, off_value=0
    )
    return targets_batch
def examine_batches(targets_batch):
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(10):
            targets = sess.run([targets_batch])
            print(targets)
        coord.request_stop()
        coord.join(threads)
The code enters through examine_batches(), having been handed the output of batch_generator(). batch_generator() calls tfrecord_to_graph_ops(), and the problem is in that function, I believe.

I am calling

targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)

on a file with 1,000 bytes (numbers 0-9). If I call eval() on this in a Session, it shows me all 1,000 elements. But if I try to put it in a batch generator, it crashes.
If I don't reshape targets, I get an error like ValueError: All shapes must be fully defined when tf.train.shuffle_batch is called. If I call targets.set_shape([1]), reminiscent of Google's CIFAR-10 example code, I get an error like Invalid argument: Shape mismatch in tuple component 0. Expected [1], got [1000] in tf.train.shuffle_batch. I also tried using tf.strided_slice to cut a chunk of the raw data - this doesn't crash, but it results in just getting the first record over and over again.
What is the right way to pull batches from a TFRecord file?
Note, I could manually write a function that chops up the raw byte data and does some sort of batching - especially easy if I were using the feed_dict approach to getting data into the graph - but I am trying to learn how to use TensorFlow's TFRecord files and their built-in batching functions.
Thanks!
Upvotes: 1
Views: 449
Reputation: 1593
Allen Lavoie pointed out the correct solution in a comment. The important missing piece was enqueue_many=True as an argument to tf.train.shuffle_batch(). The correct way to write those functions is:
def tfrecord_to_graph_ops(filenames_list):
    file_queue = tf.train.string_input_producer(filenames_list)
    reader = tf.TFRecordReader()
    _, tfrecord = reader.read(file_queue)
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={'targets': tf.FixedLenFeature([], tf.string)}
    )
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.reshape(targets, [-1])
    return targets

def batch_generator(filenames_list, batch_size=BATCH_SIZE):
    targets = tfrecord_to_graph_ops(filenames_list)
    targets_batch = tf.train.shuffle_batch(
        [targets],
        batch_size=batch_size,
        capacity=(20 * batch_size),
        min_after_dequeue=(2 * batch_size),
        enqueue_many=True
    )
    return targets_batch
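The reason this works: with the default enqueue_many=False, shuffle_batch treats each tensor in the list as one whole example, so the length-1000 vector would have to be a single example with a fully defined shape. With enqueue_many=True, the first dimension is treated as the example dimension, so the reshaped vector enqueues 1,000 scalar examples that can then be shuffled and batched individually. A plain-NumPy sketch of that semantic difference (not TensorFlow code, just an illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# The decoded uint8 vector from the answer's pipeline: shape (1000,).
decoded = np.frombuffer(bytes(range(10)) * 100, dtype=np.uint8)

# enqueue_many=False analogue: the whole vector is a single example.
queue_single = [decoded]        # 1 example of shape (1000,)

# enqueue_many=True analogue: each element along axis 0 is its own example.
queue_many = list(decoded)      # 1000 scalar examples

# A shuffled batch then draws batch_size scalar examples, e.g. shape (25,).
batch = rng.permutation(queue_many)[:25]
```

This is why the original set_shape([1]) attempt failed with "Expected [1], got [1000]": without enqueue_many=True, the queue was being handed all 1,000 elements as one component.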
Upvotes: 1