bluesummers

Reputation: 12647

TensorFlow tf.data.Dataset and bucketing

For an LSTM network, I've seen great improvements with bucketing.

I've come across the bucketing section in the TensorFlow docs (under tf.contrib).

In my network, though, I am using the tf.data.Dataset API with TFRecords, so my input pipeline looks something like this:

dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
dataset = dataset.map(_parse_function)
dataset = dataset.map(_scale_function)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(batch_size, padded_shapes={.....})

How can I incorporate the bucketing method into the tf.data.Dataset pipeline?

If it matters, in every record in the TFRecords file I have the sequence length saved as an integer.

Upvotes: 9

Views: 2417

Answers (1)

Vijay Mariappan

Reputation: 17201

Various bucketing use cases with the Dataset API are explained well here.

bucket_by_sequence_length() example:

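import tensorflow as tf  # TensorFlow 1.x API (tf.contrib does not exist in 2.x)

def elements_gen():
    # Toy variable-length sequences with one label each.
    text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
    label = [1, 2, 1, 2]
    for x, y in zip(text, label):
        yield (x, y)

def element_length_fn(x, y):
    # The bucket is chosen from the length of the sequence.
    return tf.shape(x)[0]

dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                         output_shapes=([None], []),
                                         output_types=(tf.int32, tf.int32))

# Boundaries [0, 8] define three buckets (length < 0, 0 <= length < 8, length >= 8),
# so bucket_batch_sizes needs len(bucket_boundaries) + 1 entries.
dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(
    element_length_func=element_length_fn,
    bucket_batch_sizes=[2, 2, 2],
    bucket_boundaries=[0, 8]))

batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(2):
        print('Get_next:')
        print(sess.run(batch))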

Output:

Get_next:
(array([[1, 2, 3, 0, 0],
       [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
Get_next:
(array([[1, 2, 0, 0],
       [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))
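
To make it clearer why the batches come out this way, here is a rough pure-Python sketch of the bucketing logic: pick a bucket from the sequence length, emit a batch once that bucket is full, and pad only to the longest sequence in the batch. This is an illustration and not TensorFlow's implementation. The names bucket_batches and _pad are made up for this sketch, and the order in which leftover partial batches are flushed at the end is an assumption.

```python
from bisect import bisect_right

def _pad(bucket, pad_value):
    # Pad every sequence to the longest one in this batch only.
    width = max(len(seq) for seq, _ in bucket)
    seqs = [seq + [pad_value] * (width - len(seq)) for seq, _ in bucket]
    labels = [label for _, label in bucket]
    return seqs, labels

def bucket_batches(elements, boundaries, batch_sizes, pad_value=0):
    """Yield (padded_sequences, labels) batches grouped by sequence length."""
    # N boundaries split the lengths into N + 1 buckets.
    assert len(batch_sizes) == len(boundaries) + 1
    buckets = [[] for _ in batch_sizes]
    for seq, label in elements:
        # Bucket i holds lengths with boundaries[i-1] <= length < boundaries[i].
        i = bisect_right(boundaries, len(seq))
        buckets[i].append((seq, label))
        if len(buckets[i]) == batch_sizes[i]:
            yield _pad(buckets[i], pad_value)
            buckets[i] = []
    # Emit whatever is left in partially filled buckets.
    for bucket in buckets:
        if bucket:
            yield _pad(bucket, pad_value)

# Same toy data and settings as the example above.
elements = [([1, 2, 3], 1), ([3, 4, 5, 6, 7], 2), ([1, 2], 1), ([8, 9, 0, 2], 2)]
batches = list(bucket_batches(elements, boundaries=[0, 8], batch_sizes=[2, 2, 2]))
for seqs, labels in batches:
    print(seqs, labels)
```

With boundaries [0, 8] all four sequences land in the middle bucket, so the batches are simply consecutive pairs. A boundary such as [4] would instead put the short sequences and the long sequences in separate batches, which is where the padding savings come from.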

Upvotes: 6
