Reputation: 2730
text_tensor
is a tensor of shape [None, sequence_max_length, embedding_dim]
that contains the embedding look-ups of a batch of sequences. The sequences are padded with zeros. I need to obtain a tensor named text_lengths
of shape [None]
(None is the batch size) that contains the length of each sequence without padding. I've tried a couple of scripts.
The closest I've got is the code below:
text_lens = tf.math.reduce_sum(tf.cast(tf.math.not_equal(text_tensor, tf.convert_to_tensor(numpy.zeros([embedding_dim]))), dtype=tf.int32), axis=-1)
But it still calculates the lengths incorrectly. Can anyone help me with this?
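To illustrate the shape problem: reducing only over the last axis produces one count per timestep, i.e. shape [batch, sequence_max_length], rather than one length per sequence, so a second reduction over the timestep axis is still needed. A minimal NumPy sketch of this (sample values hypothetical):

```python
import numpy as np

# batch = 1, sequence_max_length = 3, embedding_dim = 2; last timestep is padding
text = np.array([[[0.5, 0.0],
                  [1.0, 2.0],
                  [0.0, 0.0]]])

# summing element-wise non-zeros over the last axis gives per-timestep counts
per_timestep = (text != 0).astype(int).sum(axis=-1)   # shape (1, 3), not (1,)

# a further reduction over non-empty timesteps yields the actual lengths
lengths = (per_timestep > 0).sum(axis=-1)             # shape (1,) -> [2]
```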
Upvotes: 1
Views: 238
Reputation: 9797
If I've understood this correctly, after a sequence's original length you get zero vectors of size embedding_dim
for the remaining positions along the timestep axis.
import tensorflow as tf

# batch_size = 2, first sequence length = 1, second sequence length = 3
data = [[[1, 1, 1, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [0, 0, 0, 0]]]

with tf.compat.v1.Session() as sess:
    tensor = tf.constant(data, dtype=tf.int32)
    # a timestep counts as real if any of its components is non-zero;
    # reduce_any is safer than reduce_all here, since reduce_all would
    # drop real timesteps that happen to contain a zero component
    check = tf.reduce_any(tf.not_equal(tensor, 0), axis=-1)
    lengths = tf.reduce_sum(tf.cast(check, tf.int32), axis=-1)
    print(sess.run(lengths))
Output
[1 3]
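For comparison, the same length computation can be written in plain NumPy (a sketch using the sample data above; the logic carries over directly to eager-mode TF 2.x with tf.reduce_any/tf.reduce_sum):

```python
import numpy as np

data = np.array([[[1, 1, 1, 1],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]],
                 [[1, 1, 1, 1],
                  [1, 1, 1, 1],
                  [1, 1, 1, 1],
                  [0, 0, 0, 0]]])

# a timestep is real if any embedding component is non-zero
is_real = (data != 0).any(axis=-1)   # shape (2, 4), boolean

# counting real timesteps per sequence gives the lengths
lengths = is_real.sum(axis=-1)       # -> array([1, 3])
```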
Upvotes: 1