sariii

Reputation: 2140

How can I get the top-k smallest tensor values per row based on a condition in TensorFlow?

I have a tensor like this:

sim_topics = [[0.65 0.   0.   0.   0.42 0.   0.   0.51 0.   0.34 0.  ]
              [0.   0.51 0.   0.   0.52 0.   0.   0.   0.53 0.42 0.  ]
              [0.   0.32 0.   0.50 0.34 0.   0.   0.39 0.32 0.52 0.  ]
              [0.   0.23 0.37 0.   0.   0.37 0.37 0.   0.47 0.39 0.3 ]]

and a boolean tensor like this:

bool_t = [False  True  True  True]

I want to filter sim_topics row by row using the flags in bool_t: if a row's flag is True, keep only the k smallest nonzero values in that row and set every other entry to zero; if the flag is False, leave the row as it is.

So the expected output would be (here k = 2):

[[0.65 0.   0.   0.   0.42 0.   0.   0.51 0.   0.34 0.  ]
 [0.   0.51 0.   0.   0.   0.   0.   0.   0.   0.42 0.  ]
 [0.   0.32 0.   0.   0.   0.   0.   0.   0.32 0.   0.  ]
 [0.   0.23 0.   0.   0.   0.   0.   0.   0.   0.   0.3 ]]
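To pin down the intended behaviour independently of TensorFlow, here is a small plain-Python reference sketch. It assumes, as the expected output suggests, that "smallest" means smallest *nonzero* value, and that ties are broken by column index; the function name keep_k_smallest_nonzero is hypothetical and only for illustration.

```python
# Plain-Python reference for the desired filtering (no TensorFlow involved).
# Assumption: the k smallest *nonzero* values of each flagged row are kept.

def keep_k_smallest_nonzero(rows, flags, k):
    """Return a copy of rows where each flagged row keeps only its k smallest nonzero values."""
    result = []
    for row, flag in zip(rows, flags):
        if not flag:
            result.append(list(row))  # unflagged rows are copied unchanged
            continue
        # column indices of nonzero entries, ordered by value (ties broken by column index)
        nonzero = sorted((j for j, v in enumerate(row) if v != 0), key=lambda j: (row[j], j))
        keep = set(nonzero[:k])
        # zero out everything that is not among the k smallest nonzero entries
        result.append([v if j in keep else 0.0 for j, v in enumerate(row)])
    return result


sim_topics = [
    [0.65, 0, 0, 0, 0.42, 0, 0, 0.51, 0, 0.34, 0],
    [0, 0.51, 0, 0, 0.52, 0, 0, 0, 0.53, 0.42, 0],
    [0, 0.32, 0, 0.50, 0.34, 0, 0, 0.39, 0.32, 0.52, 0],
    [0, 0.23, 0.37, 0, 0, 0.37, 0.37, 0, 0.47, 0.39, 0.3],
]
bool_t = [False, True, True, True]

for row in keep_k_smallest_nonzero(sim_topics, bool_t, k=2):
    print(row)
```

A slow loop like this is only meant as a specification to check a vectorized TensorFlow version against.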

I first tried to do this by using tf.boolean_mask and tf.where to get the indices I want and then take the smallest values. However, tf.where does not give me the indices of the entries that are zero.
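The pitfall with the zeros can be reproduced in NumPy, used here only as an easy-to-run stand-in: treating the float values themselves as the condition returns only the nonzero positions. The analogous fix in TensorFlow is to build a boolean condition explicitly, for example tf.equal(x, 0), and pass that to tf.where, which then returns the coordinates where the condition is True.

```python
import numpy as np

row = np.array([0.0, 0.32, 0.0, 0.50, 0.34, 0.0, 0.0, 0.39, 0.32, 0.52, 0.0])

# Using the float values directly as the condition keeps only nonzero positions,
# so the zero entries never appear in the result.
nonzero_idx = np.flatnonzero(row)

# Comparing against 0 first turns the zeros into True, so their positions are returned.
zero_idx = np.flatnonzero(row == 0)

print(nonzero_idx.tolist())  # [1, 3, 4, 7, 8, 9]
print(zero_idx.tolist())     # [0, 2, 5, 6, 10]
```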

Upvotes: 5

Views: 360

Answers (1)

Onyambu

Reputation: 79208

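import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

sim_topics = tf.constant([[0.65, 0, 0, 0, 0.42, 0, 0, 0.51, 0, 0.34, 0],
                          [0, 0.51, 0, 0, 0.52, 0, 0, 0, 0.53, 0.42, 0],
                          [0, 0.32, 0, 0.50, 0.34, 0, 0, 0.39, 0.32, 0.52, 0],
                          [0, 0.23, 0.37, 0, 0, 0.37, 0.37, 0, 0.47, 0.39, 0.3]])
bool_t = tf.constant([False, True, True, True])
k = 2
dim0 = sim_topics.shape[0]
a = tf.cast(tf.equal(sim_topics, 0), sim_topics.dtype)  # 1.0 where an entry is zero
b = tf.reshape(tf.reduce_sum(a, 1) + k, (dim0, -1))     # per row: number of zeros + k, shape (dim0, 1)
c = tf.cast(tf.argsort(tf.argsort(sim_topics, 1), 1), sim_topics.dtype)  # rank of each entry within its row
d = tf.logical_or(tf.less(c, b), tf.reshape(tf.logical_not(bool_t), (dim0, -1)))  # keep: low rank, or row flag False
with tf.Session() as sess:
    print(sess.run(sim_topics * tf.cast(d, sim_topics.dtype)))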


[[0.65 0.   0.   0.   0.42 0.   0.   0.51 0.   0.34 0.  ]
 [0.   0.51 0.   0.   0.   0.   0.   0.   0.   0.42 0.  ]
 [0.   0.32 0.   0.   0.   0.   0.   0.   0.32 0.   0.  ]
 [0.   0.23 0.   0.   0.   0.   0.   0.   0.   0.   0.3 ]]
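The answer relies on a rank trick: argsort applied twice gives each entry's rank within its row, the zeros occupy the lowest ranks (this assumes the values are non-negative), so keeping ranks below "number of zeros + k" keeps the zeros plus the k smallest nonzero values, and OR-ing with the negated flags leaves unflagged rows intact. Below is a NumPy sketch of the same idea, using NumPy only because it is easy to run; the function name is hypothetical, and kind="stable" is my own choice to make ties deterministic (the answer's tf.argsort call does not request a stable sort).

```python
import numpy as np


def mask_k_smallest_nonzero(x, flags, k):
    """NumPy sketch of the rank trick; assumes x is non-negative so zeros sort first."""
    # rank of every entry within its row (0 = smallest); argsort of argsort inverts the permutation
    ranks = np.argsort(np.argsort(x, axis=1, kind="stable"), axis=1, kind="stable")
    # per row: zeros fill the lowest ranks, so allow ranks below (#zeros + k)
    limit = (x == 0).sum(axis=1, keepdims=True) + k
    # keep an entry if it is within that rank limit, or if its row is not flagged
    keep = (ranks < limit) | ~flags[:, None]
    return x * keep


sim_topics = np.array([
    [0.65, 0, 0, 0, 0.42, 0, 0, 0.51, 0, 0.34, 0],
    [0, 0.51, 0, 0, 0.52, 0, 0, 0, 0.53, 0.42, 0],
    [0, 0.32, 0, 0.50, 0.34, 0, 0, 0.39, 0.32, 0.52, 0],
    [0, 0.23, 0.37, 0, 0, 0.37, 0.37, 0, 0.47, 0.39, 0.3],
])
bool_t = np.array([False, True, True, True])

print(mask_k_smallest_nonzero(sim_topics, bool_t, k=2))
```

Each line maps onto one step of the TensorFlow code above (tf.equal/tf.reduce_sum for the zero count, tf.argsort twice for the ranks, tf.less and tf.logical_or for the mask), so it can be used to check the TensorFlow version's output on small inputs.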

Upvotes: 2
