Let's consider a toy, ordered dataset with two features:
value (e.g. 1, 2, 3, 4, 5, 111, 222, 333, 444, 555)
sequence_id (e.g. 0, 0, 0, 0, 0, 1, 1, 1, 1, 1)
This data consists of two flattened sequences concatenated together: 1, 2, 3, 4, 5 (sequence 0) and 111, 222, 333, 444, 555 (sequence 1).
I would like to generate sequences of size t (say 3) consisting of consecutive elements from the same sequence (same sequence_id). I do not want a sequence to contain elements belonging to different sequence_id values.
For instance, without any shuffling, I would like to get the following batches:
1, 2, 3
2, 3, 4
3, 4, 5
111, 222, 333
222, 333, 444
333, 444, 555
1, 2, 3
and so on.
I know how to generate sequence data using tf.data.Dataset.window or tf.data.Dataset.batch, but I do not know how to prevent a sequence from containing a mix of different sequence_id values (e.g. the sequence 4, 5, 111 should not be valid, as it mixes elements from sequence 0 and sequence 1).
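To make the expected result concrete, here is a plain-Python sketch of the windows I am after. It does not use tf.data at all and only serves as a reference for the desired output; the helper name windows_per_sequence is made up for this illustration.

```python
from itertools import groupby


def windows_per_sequence(values, sequence_ids, t):
    """Yield length-t windows of consecutive values sharing one sequence_id.

    Hypothetical reference helper (not part of tf.data): it assumes the
    input is ordered so that each sequence_id forms one contiguous run.
    """
    pairs = zip(values, sequence_ids)
    # groupby splits the flattened data back into one run per sequence_id.
    for _, group in groupby(pairs, key=lambda pair: pair[1]):
        run = [value for value, _ in group]
        # Slide a window of size t with stride 1 inside this run only.
        for start in range(len(run) - t + 1):
            yield run[start:start + t]


values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
sequence_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

expected = list(windows_per_sequence(values, sequence_ids, 3))
for window in expected:
    print(window)
```

One pass over the data yields six windows, and none of them crosses the boundary between sequence 0 and sequence 1.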
Below is my failed attempt:
import tensorflow as tf

# TensorFlow 1.x API (tf.Session, make_initializable_iterator).
# window() slides over the flattened data and ignores sequence_id,
# so some windows straddle the boundary between the two sequences.
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
    .window(3, 1, drop_remainder=True)\
    .repeat(-1)\
    .flat_map(lambda x, y: x.batch(3))\
    .batch(10)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))
that outputs:
[[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 111] # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
[ 5 111 222] # bad
[111 222 333] # good
[222 333 444] # good
[333 444 555] # good
[ 1 2 3] # good
[ 2 3 4]] # good
Upvotes: 1
Views: 1158
Reputation: 6166
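You can use filter() to check whether the sequence_id is the same for every element of a window, and drop the windows where it is not. Because the filter() transformation does not currently support nested datasets as inputs, you first need zip() to combine each window's value and sequence_id datasets, and batch() to turn them into regular tensors.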
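import tensorflow as tf

# TensorFlow 1.x API (tf.Session, make_initializable_iterator).
# 1. window(): sliding windows of size 3 over (value, sequence_id) pairs.
# 2. flat_map() + zip() + batch(): turn each nested window into two tensors.
# 3. filter(): keep a window only if it contains a single unique sequence_id.
# 4. map(): drop the sequence_id and keep the values.
data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
    .window(3, 1, drop_remainder=True)\
    .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(3))\
    .filter(lambda x, y: tf.equal(tf.size(tf.unique(y)[0]), 1))\
    .map(lambda x, y: x)\
    .repeat(-1)\
    .batch(10)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))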
[[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]
[222 333 444]
[333 444 555]
[ 1 2 3]
[ 2 3 4]
[ 3 4 5]
[111 222 333]]
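The snippet above relies on TensorFlow 1.x calls (tf.Session and make_initializable_iterator). As a rough sketch only, the same window, zip, filter pipeline might look like this under TensorFlow 2.x with eager execution. The variable names and the use of tf.reduce_all in place of tf.unique are my own choices for illustration, and repeat() and the outer batch(10) are left out so the result is finite.

```python
import tensorflow as tf

values = [1, 2, 3, 4, 5, 111, 222, 333, 444, 555]
sequence_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
t = 3  # window size

windows = (
    tf.data.Dataset.from_tensor_slices((values, sequence_ids))
    # Sliding windows of size t with shift 1; each window is a pair of datasets.
    .window(t, shift=1, drop_remainder=True)
    # Zip the two nested datasets and batch them into two tensors of length t.
    .flat_map(lambda x, y: tf.data.Dataset.zip((x, y)).batch(t, drop_remainder=True))
    # Keep a window only if every sequence_id equals the first one.
    .filter(lambda x, y: tf.reduce_all(tf.equal(y, y[0])))
    # Drop the sequence_id and keep the values.
    .map(lambda x, y: x)
)

# Eager iteration replaces the explicit session and iterator initializer.
result = [window.tolist() for window in windows.as_numpy_iterator()]
print(result)
```

Appending .repeat().batch(10) to the pipeline should restore the endlessly repeating batches of ten windows shown in the output above.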
Upvotes: 2