user3669481

Reputation: 327

Creating a tf.sequence_mask based on input values

The example in the TensorFlow tutorial shows how to create a mask from a list of sequence lengths:

tf.sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
                                #  [True, True, True, False, False],
                                #  [True, True, False, False, False]]

What if I want to create the mask dynamically, based on the values in my batch? Say my input is [[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]], and in each row I want everything up to and including the first 4 to be True and everything after the first 4 to be False. The resulting mask should be:

[[True, True, True, True, True],
[True, True, True, False, False],
[True, True, True, False, False]]

I am trying to use this mask as the weights applied to my sequence_loss tensor.

Upvotes: 0

Views: 2764

Answers (1)

akuiper

Reputation: 214927

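First find the column index of the first 4 in each row:

import tensorflow as tf

x = tf.constant([[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]])

# 1 where x == 4, else 0; argmax then gives the column of the first 4 in each row
cond = tf.cast(tf.equal(x, 4), tf.int8)
# Reshape to (batch, 1) so it can be compared against a 1-D range later
idx4_ = tf.reshape(tf.argmax(cond, axis=1, output_type=tf.int32), (-1, 1))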

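The next step is optional if every row contains at least one 4 (in that case idx4_ can be used directly as idx4). It sets the index to -1 for rows without a 4, so that those rows are masked out entirely; without it, argmax returns 0 for such rows and their first position would be True: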

idx4 = tf.where(
    # True for rows that contain a 4 (the argument is spelled keep_dims in TF < 1.5)
    tf.equal(tf.reduce_max(cond, axis=1, keepdims=True), 1),
    idx4_,
    tf.constant(-1, shape=idx4_.shape)
)

Create the mask by comparing the index of first 4 with a 1d range index:

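# idx4 has shape (batch, 1) and the range has shape (time,),
# so the comparison broadcasts to (batch, time)
mask = idx4 >= tf.range(x.shape[1])

with tf.Session() as sess:
    print(sess.run(mask))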
#[[ True  True  True  True  True]
# [ True  True  True False False]
# [ True  True  True False False]]
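
To make the comparison step easier to follow, here is a minimal pure-Python sketch of the same logic, without TensorFlow. The function name reference_mask is made up for illustration, and the -1 fallback for rows without the value is assumed to match the optional tf.where step above. It can also serve as a reference when checking the output of the graph version.

```python
def reference_mask(rows, value):
    """For each row, True up to and including the first `value`, False afterwards."""
    masks = []
    for row in rows:
        # Index of the first occurrence, or -1 when the value is absent
        # (assumption: same fallback as the optional tf.where step).
        first = row.index(value) if value in row else -1
        # Same comparison as `idx4 >= tf.range(...)`: position j is True while j <= first.
        masks.append([first >= j for j in range(len(row))])
    return masks


x = [[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]]
for row_mask in reference_mask(x, 4):
    print(row_mask)
```

A row that does not contain the value gets first = -1, so every position in it compares as False.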

Or use tf.sequence_mask, passing the index of the first 4 plus one as the sequence length:

import tensorflow as tf

x = tf.constant([[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]])

cond = tf.cast(tf.equal(x, 4), tf.int8)
# No reshape here: a 1-D vector of lengths yields a (batch, time) mask
idx4_ = tf.argmax(cond, axis=1, output_type=tf.int32)

idx4 = tf.where(
    tf.equal(tf.reduce_max(cond, axis=1), 1), 
    idx4_, 
    tf.constant(-1, shape=idx4_.shape)
)

with tf.Session() as sess:
    print(sess.run(tf.sequence_mask(idx4+1, x.shape[1])))

#[[ True  True  True  True  True]
# [ True  True  True False False]
# [ True  True  True False False]]

If x is a placeholder whose shape is not known beforehand, use the dynamic shape (tf.shape) instead of the static one:

import tensorflow as tf

x = tf.placeholder(tf.int32, shape=[None, None])
cond = tf.cast(tf.equal(x, 4), tf.int8)
idx4_ = tf.argmax(cond, axis=1, output_type=tf.int32)

idx4 = tf.where(
    tf.equal(tf.reduce_max(cond, axis=1), 1), 
    idx4_, 
    tf.fill(tf.shape(idx4_), -1)
)

mask = tf.sequence_mask(idx4+1, tf.shape(x)[-1])
with tf.Session() as sess:
    print(sess.run(mask, {x: [[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]]}))

#[[ True  True  True  True  True]
# [ True  True  True False False]
# [ True  True  True False False]]
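
The snippets above use tf.Session and tf.placeholder, which are TensorFlow 1.x APIs. As a rough sketch, assuming TensorFlow 2.x with eager execution, the same idea could be wrapped in a function. The name mask_up_to_first and its value parameter are made up for illustration, and the float cast at the end is only one possible way to turn the mask into loss weights. Like the snippets above, it relies on tf.argmax returning the index of the first maximum when there are ties.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled


def mask_up_to_first(x, value):
    """Boolean mask that is True up to and including the first `value` in each row."""
    hit = tf.equal(x, value)  # (batch, time) booleans
    # Column of the first hit in each row (0 if the row has no hit at all).
    first = tf.argmax(tf.cast(hit, tf.int32), axis=1, output_type=tf.int32)
    # Rows without a hit get length 0, so they are masked out entirely.
    has_hit = tf.reduce_any(hit, axis=1)
    lengths = tf.where(has_hit, first + 1, tf.zeros_like(first))
    # tf.shape gives the dynamic width, so this also works for variable-width input.
    return tf.sequence_mask(lengths, maxlen=tf.shape(x)[1])


x = tf.constant([[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]])
mask = mask_up_to_first(x, 4)
weights = tf.cast(mask, tf.float32)  # 1.0 where the mask is True, 0.0 elsewhere
print(mask.numpy())
```

Computing a length of 0 for rows without a hit has the same effect as the -1 index plus one used in the sequence_mask variants above.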

Upvotes: 3
