BiBi

Reputation: 7908

Batch sequential data with tf.data

Let's consider an ordered toy dataset with two features:

value:        1  2  3  4  5  111  222  333  444  555
sequence_id:  0  0  0  0  0    1    1    1    1    1

This data is basically two flattened sequences concatenated: 1, 2, 3, 4, 5 (sequence 0) and 111, 222, 333, 444, 555 (sequence 1).

I would like to generate sequences of size t (say 3) consisting of consecutive elements from the same source sequence (same sequence_id); I do not want a generated sequence to contain elements belonging to different sequence_id values.

For instance, without any shuffling, I would like to get the following batches:

[[  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]
 [222 333 444]
 [333 444 555]]

I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing a mix of different sequence_id values (e.g. the sequence 4, 5, 111 should not be valid, as it mixes elements from sequence 0 and sequence 1).

Below is my failed attempt:

import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555], 
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
                .window(3, 1, drop_remainder=True)\
                .repeat(-1)\
                .flat_map(lambda x, y: x.batch(3))\
                .batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

that outputs:

[[  1   2   3]   # good
 [  2   3   4]   # good
 [  3   4   5]   # good
 [  4   5 111]   # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
 [  5 111 222]   # bad
 [111 222 333]   # good
 [222 333 444]   # good
 [333 444 555]   # good
 [  1   2   3]   # good
 [  2   3   4]]  # good

Upvotes: 1

Views: 1158

Answers (1)

giser_yugang

Reputation: 6166

You can use filter() to check whether the sequence_id values within a window are consistent. Because the filter() transformation does not currently support nested datasets as inputs, you first need to flatten each window into a single element with zip() and batch().

import tensorflow as tf

data = (tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))
        .window(3, 1, drop_remainder=True)
        # flatten each nested window into a single (values, sequence_ids) element of length 3
        .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(3))
        # keep only windows whose sequence_id is constant
        .filter(lambda x, y: tf.equal(tf.size(tf.unique(y)[0]), 1))
        # drop the sequence_id and keep only the values
        .map(lambda x, y: x)
        .repeat(-1)
        .batch(10))
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

[[  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]
 [222 333 444]
 [333 444 555]
 [  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]]
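
For completeness, the same idea carries over to TF 2.x, where eager execution removes the need for Sessions and initializable iterators. Below is a minimal sketch of the equivalent pipeline (assuming TensorFlow 2.x; the dataset is iterated directly and tensors are materialized with .numpy()):

import tensorflow as tf  # assumes TF 2.x

values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
seq_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

data = (tf.data.Dataset.from_tensor_slices((values, seq_ids))
        .window(3, shift=1, drop_remainder=True)
        # flatten each nested window into one (values, ids) pair of length-3 tensors
        .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(3))
        # keep only windows with a single distinct sequence_id
        .filter(lambda x, y: tf.equal(tf.size(tf.unique(y)[0]), 1))
        .map(lambda x, y: x)
        .repeat()
        .batch(10))

for batch in data.take(1):
    print(batch.numpy())

The zip((x, y)).batch(3) step is what lets filter() see each window as a single element: window() yields nested datasets, so a window's values and sequence_ids must be collapsed into plain tensors before you can filter on them.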

Upvotes: 2
