I have a tensor like this:
sim_topics = [[0.65 0. 0. 0. 0.42 0. 0. 0.51 0. 0.34 0.]
[0. 0.51 0. 0. 0.52 0. 0. 0. 0.53 0.42 0.]
[0. 0.32 0. 0.50 0.34 0. 0. 0.39 0.32 0.52 0.]
[0. 0.23 0.37 0. 0. 0.37 0.37 0. 0.47 0.39 0.3 ]]
and a boolean tensor like this:
bool_t = [False True True True]
I want to select part of sim_topics based on the boolean flags in bool_t: for each row whose flag is True, keep only the k smallest non-zero values and set the rest to zero; rows whose flag is False should be left as they are.
So the expected output would be like this (here k=2):
[[0.65 0. 0. 0. 0.42 0. 0. 0.51 0. 0.34 0.]
[0. 0.51 0. 0. 0. 0. 0. 0. 0. 0.42 0.]
[0. 0.32 0. 0. 0. 0. 0. 0. 0.32 0. 0.]
[0. 0.23 0. 0. 0. 0. 0. 0. 0. 0. 0.3]]
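For reference, the rule described above can be spelled out as a plain NumPy loop. This is a sketch in NumPy rather than TensorFlow, and the function name keep_k_smallest_nonzero is hypothetical. It treats "smallest" as smallest among the non-zero entries and assumes bool_t has one flag per row of sim_topics.

```python
import numpy as np

def keep_k_smallest_nonzero(x, flags, k):
    """Hypothetical loop-based reference: for each flagged row, keep only the
    k smallest non-zero entries and zero out the rest; other rows are copied."""
    out = x.copy()
    for i, flag in enumerate(flags):
        if not flag:
            continue                                   # unflagged row: leave as is
        nz = np.flatnonzero(x[i])                      # column indices of non-zero entries
        order = np.argsort(x[i, nz], kind="stable")    # sort those entries ascending
        keep = nz[order[:k]]                           # columns of the k smallest
        row = np.zeros_like(x[i])
        row[keep] = x[i, keep]
        out[i] = row
    return out

sim_topics = np.array([
    [0.65, 0, 0, 0, 0.42, 0, 0, 0.51, 0, 0.34, 0],
    [0, 0.51, 0, 0, 0.52, 0, 0, 0, 0.53, 0.42, 0],
    [0, 0.32, 0, 0.50, 0.34, 0, 0, 0.39, 0.32, 0.52, 0],
    [0, 0.23, 0.37, 0, 0, 0.37, 0.37, 0, 0.47, 0.39, 0.3],
])
bool_t = np.array([False, True, True, True])

result = keep_k_smallest_nonzero(sim_topics, bool_t, k=2)
print(result)
```

A loop is slow for large tensors, but it makes the intended semantics unambiguous and is handy for checking a vectorized version against.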
I first tried to do this with tf.boolean_mask and tf.where, to get the indices I want and then take the k smallest values. However, when I use tf.where, it does not give me the indices of the entries that are zero.
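With a single argument, tf.where returns the coordinates of the entries of its condition that are true (a numeric condition counts as true where it is non-zero), so zero-valued entries never appear unless the condition tests for them explicitly, e.g. tf.where(tf.equal(sim_topics, 0)). NumPy's np.argwhere behaves the same way; the small illustration below uses NumPy, not TensorFlow, and the row values are just a slice of sim_topics.

```python
import numpy as np

row = np.array([0.0, 0.51, 0.0, 0.0, 0.52])

# One-argument "where"-style lookups return coordinates of non-zero entries,
# so the zeros are skipped:
print(np.argwhere(row).ravel())        # [1 4]

# To get the zero entries, build an explicit boolean condition:
print(np.argwhere(row == 0).ravel())   # [0 2 3]
```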
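import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)
# assumes sim_topics is a float tensor of shape (rows, cols) and bool_t a bool tensor of shape (rows,)
k = 2
dim0 = sim_topics.shape[0]
# a: 1.0 where an entry is zero, 0.0 elsewhere
a = tf.cast(tf.equal(sim_topics, 0), sim_topics.dtype)
# b: per-row cutoff = (number of zeros in the row) + k, shape (dim0, 1)
b = tf.reshape(tf.reduce_sum(a, 1) + k, (dim0, -1))
# c: rank of each entry within its row (0 = smallest); with no negative values, zeros rank first
c = tf.cast(tf.argsort(tf.argsort(sim_topics, 1), 1), sim_topics.dtype)
# d: keep an entry if it is a zero or one of the k smallest non-zeros,
#    or if its row is not flagged in bool_t
d = tf.logical_or(tf.less(c, b), tf.reshape(tf.logical_not(bool_t), (dim0, -1)))
with tf.Session() as sess:
    print(sess.run(sim_topics * tf.cast(d, sim_topics.dtype)))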
[[0.65 0. 0. 0. 0.42 0. 0. 0.51 0. 0.34 0. ]
[0. 0.51 0. 0. 0. 0. 0. 0. 0. 0.42 0. ]
[0. 0.32 0. 0. 0. 0. 0. 0. 0.32 0. 0. ]
[0. 0.23 0. 0. 0. 0. 0. 0. 0. 0. 0.3 ]]
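The same rank trick can be checked without a TensorFlow session. Below is a sketch of a NumPy port (the function name mask_k_smallest_nonzero is hypothetical): argsort of argsort gives each entry's rank within its row, and because zeros sort first when there are no negative values, "rank < zeros in row + k" selects the zeros plus the k smallest non-zero values.

```python
import numpy as np

def mask_k_smallest_nonzero(x, flags, k):
    """Hypothetical NumPy port of the TensorFlow rank trick; assumes x >= 0."""
    zeros_per_row = (x == 0).sum(axis=1, keepdims=True)           # like a, summed per row
    cutoff = zeros_per_row + k                                     # like b
    ranks = np.argsort(np.argsort(x, axis=1, kind="stable"), axis=1)  # like c
    keep = (ranks < cutoff) | ~flags[:, None]                      # like d
    return x * keep                                                # dropped entries become 0

sim_topics = np.array([
    [0.65, 0, 0, 0, 0.42, 0, 0, 0.51, 0, 0.34, 0],
    [0, 0.51, 0, 0, 0.52, 0, 0, 0, 0.53, 0.42, 0],
    [0, 0.32, 0, 0.50, 0.34, 0, 0, 0.39, 0.32, 0.52, 0],
    [0, 0.23, 0.37, 0, 0, 0.37, 0.37, 0, 0.47, 0.39, 0.3],
])
bool_t = np.array([False, True, True, True])

result = mask_k_smallest_nonzero(sim_topics, bool_t, k=2)
print(result)
```

Two caveats apply to both versions: if a row contains negative values they rank below the zeros and the cutoff no longer lines up, and when several equal values straddle the k-th position, which of them survives depends on the sort's tie-breaking.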