Abhilash Awasthi

Filter a tensor based on a Python list in TensorFlow

I have a tensor a of type tf.int64. I want to filter this tensor based on a given Python list.
For example:

import tensorflow as tf

l = [1, 2, 3]
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)

I need a tensor containing the values 1, 2, 3 but not 4, i.e. a filtered down to the elements that appear in l. How can I do this in TensorFlow?

Answers (1)

javidcf

You may use tf.sets.set_intersection:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
    # tf.sets.intersection in more recent versions
    b = tf.sets.set_intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))
    b = tf.squeeze(tf.sparse.to_dense(b), 0)
    print(sess.run(b))
    # [1 2 3]
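The snippet above targets TensorFlow 1.x (explicit graph plus tf.Session). A minimal sketch of the same set-intersection idea under TensorFlow 2.x eager execution, assuming TF 2.x is installed, might look like this. It uses tf.sets.intersection, the newer name mentioned in the code comment:

```python
import tensorflow as tf

# TF 2.x sketch: eager execution, so no Graph/Session is needed.
l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)

# tf.sets.intersection operates on the last dimension of rank >= 2 inputs,
# so add a leading batch dimension of size 1 to both tensors.
b = tf.sets.intersection(tf.expand_dims(a, 0), tf.expand_dims(l, 0))

# The result is a SparseTensor; densify it and drop the batch dimension.
b = tf.squeeze(tf.sparse.to_dense(b), 0)
print(b.numpy())  # [1 2 3]
```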

However, this probably won't do what you want in many cases: duplicate elements are collapsed into a single occurrence, and the output is sorted. A more general approach is to build a boolean mask:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
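    # Compare each element of a with every element of l: shape [len(a), len(l)].
    m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
    # Keep only the elements of a that matched at least one element of l.
    b = tf.boolean_mask(a, m)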
    print(sess.run(b))
    # [1 2 3]
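As a sketch of why the mask is more general, here is the same approach under TensorFlow 2.x eager execution (assumed installed), applied to a hypothetical input containing a duplicate and in no particular order. The helper name filter_by_list is made up for this illustration. The mask keeps the original order and both copies of 3, whereas the set-intersection version would return [1 2 3] for this input:

```python
import tensorflow as tf

def filter_by_list(a, l):
    """Hypothetical helper: keep elements of a that occur in l,
    preserving their order and any duplicates."""
    # [len(a), 1] vs [len(l)] broadcasts to a [len(a), len(l)] comparison matrix.
    matches = tf.equal(tf.expand_dims(a, 1), l)
    # An element of a is kept if it equals at least one element of l.
    mask = tf.reduce_any(matches, axis=1)
    return tf.boolean_mask(a, mask)

l = tf.constant([1, 2, 3], dtype=tf.int64)
# Hypothetical input with a repeated value (3) and no particular order.
a = tf.constant([3, 1, 4, 3, 2], dtype=tf.int64)
kept = filter_by_list(a, l)
print(kept.numpy())  # [3 1 3 2]
```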

The mask approach performs a quadratic number of comparisons (len(a) * len(l)) and materializes that many booleans, but I don't think TensorFlow has anything better, such as an equivalent of np.isin.
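If the len(a) * len(l) comparison matrix becomes too large, one alternative not covered in the answer is to sort l and binary-search it with tf.searchsorted. This is a sketch assuming TensorFlow 2.x and a non-empty 1-D l with the same dtype as a. The helper name isin_sorted is hypothetical. The cost is roughly O((len(a) + len(l)) * log(len(l))) instead of O(len(a) * len(l)):

```python
import tensorflow as tf

def isin_sorted(a, l):
    """Hypothetical helper: boolean mask that is True where an element of a
    occurs in l. Assumes l is a non-empty 1-D tensor with the same dtype as a."""
    l_sorted = tf.sort(l)
    # For each element of a, index of the first element of l_sorted that is >= it.
    idx = tf.searchsorted(l_sorted, a, side="left")
    # An index equal to len(l) means "larger than every element of l"; clamp it
    # so tf.gather stays in range (the equality test below rejects those cases).
    idx = tf.minimum(idx, tf.size(l_sorted) - 1)
    return tf.equal(tf.gather(l_sorted, idx), a)

l = tf.constant([1, 2, 3], dtype=tf.int64)
a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
b = tf.boolean_mask(a, isin_sorted(a, l))
print(b.numpy())  # [1 2 3]
```

Unlike the set-intersection version, this keeps the order and duplicates of a, just like the mask approach.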
