I have a tensor a of type tf.int64. I want to filter this tensor based on a given Python list. For example:
l = [1,2,3]
a = tf.constant([1,2,3,4], dtype=tf.int64)
I need a tensor containing the values 1, 2 and 3 but not 4, i.e. only the elements of a that also appear in l. How can I do this in TensorFlow?
You may use tf.sets.set_intersection:
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
    # tf.sets.intersection in more recent versions
    b = tf.sets.set_intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))
    b = tf.squeeze(tf.sparse.to_dense(b), 0)
    print(sess.run(b))
    # [1 2 3]
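If you are on TensorFlow 2.x with eager execution, there is no tf.Session. A minimal sketch of the same set-intersection idea, assuming TF 2.x where the op is named tf.sets.intersection, might look like this:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)

# The tf.sets ops work along the last dimension, so (as in the TF 1.x
# version) add a leading batch dimension of size 1 to both tensors.
b = tf.sets.intersection(a[None, :], l[None, :])
# The result is a SparseTensor: densify it and drop the batch dimension.
b = tf.squeeze(tf.sparse.to_dense(b), 0)
print(b.numpy())  # [1 2 3]
```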
However, this is probably not what you want in many cases: if a contains duplicate elements, the intersection discards them, and it also sorts the output. More generally, you can build a boolean membership mask and apply it:
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
    m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
    b = tf.boolean_mask(a, m)
    print(sess.run(b))
    # [1 2 3]
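In TensorFlow 2.x the mask approach works the same way without a session. The sketch below wraps it in a helper; filter_by_list is a hypothetical name, and it assumes a is 1-D and that the list values can be converted to a's dtype:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def filter_by_list(a, values):
    """Hypothetical helper: keep the elements of the 1-D tensor `a` found in `values`."""
    # tf.equal needs both operands to share a dtype, so convert the list to a's dtype.
    v = tf.constant(values, dtype=a.dtype)
    # [n, 1] compared with [1, k] broadcasts to an [n, k] boolean table.
    matches = tf.equal(a[:, None], v[None, :])
    # Keep an element of `a` if it equals at least one entry of `values`.
    mask = tf.reduce_any(matches, axis=1)
    return tf.boolean_mask(a, mask)


a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
print(filter_by_list(a, [1, 2, 3]).numpy())  # [1 2 3]

# Unlike the set intersection, duplicates and the original order are preserved.
print(filter_by_list(tf.constant([3, 1, 3, 4], dtype=tf.int64), [1, 3]).numpy())  # [3 1 3]
```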
This performs a quadratic number of comparisons (every element of a against every element of l), but I don't think TensorFlow has anything better, such as an equivalent of np.isin.
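For comparison, np.isin in NumPy computes exactly this membership mask. The sketch below (plain NumPy, no TensorFlow) shows that it agrees with the broadcasted equality test used in the TensorFlow answer:

```python
import numpy as np

l = [1, 2, 3]
a = np.array([1, 2, 3, 4], dtype=np.int64)

# np.isin returns a boolean mask marking which elements of `a` occur in `l`.
mask = np.isin(a, l)
print(a[mask])  # [1 2 3]

# The same mask built explicitly from a [len(a), len(l)] table of equality tests,
# mirroring tf.equal + tf.reduce_any(axis=1).
same_mask = (a[:, None] == np.array(l)[None, :]).any(axis=1)
print(np.array_equal(mask, same_mask))  # True
```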