Say I have a vector with values [0,4,1,2,8,7,0,2]. How can I get a binary vector marking the top k values (k = 3), i.e. [0,1,0,0,1,1,0,0], in TensorFlow?
TensorFlow's tf.math.top_k will find the k largest values, along with their indices, for you. But to turn those indices into a binary mask, you need tf.scatter_nd.
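To see what tf.math.top_k actually hands you before building the mask, here is a small check on the same vector. It returns both the values and their positions, sorted by value in descending order by default; only the indices are needed for the mask.

```python
import tensorflow as tf

x = tf.constant([0, 4, 1, 2, 8, 7, 0, 2])
# top_k returns (values, indices); with the default sorted=True,
# values come out in descending order.
values, indices = tf.math.top_k(x, k=3)
print(values.numpy())   # [8 7 4]
print(indices.numpy())  # [4 5 1]
```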
This code should work for the task:
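import tensorflow as tf
x = tf.convert_to_tensor([0,4,1,2,8,7,0,2])
# indices of the k largest values
_, indices = tf.math.top_k(x, k=3)
# write a 1 at each of those indices into a zero tensor shaped like x
result = tf.scatter_nd(tf.expand_dims(indices, 1), tf.ones_like(indices), tf.shape(x))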
Output:
<tf.Tensor: id=47, shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0], dtype=int32)>
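The scatter_nd call above, as written, targets a single 1-D vector. If you need the mask for a batch of vectors, one alternative (not part of the original answer) is to one-hot encode the top-k indices and take the maximum over the k axis. This is a sketch: top_k_mask is a hypothetical helper name, top_k works along the last axis, and ties are resolved the way top_k resolves them (the lower index wins).

```python
import tensorflow as tf

def top_k_mask(x, k):
    """Hypothetical helper: int32 0/1 mask of the k largest entries along the last axis."""
    _, indices = tf.math.top_k(x, k=k)                      # shape [..., k]
    depth = tf.shape(x)[-1]                                 # length of the last axis, n
    one_hot = tf.one_hot(indices, depth, dtype=tf.int32)    # shape [..., k, n]
    # Collapse the k one-hot rows into a single mask per vector.
    return tf.reduce_max(one_hot, axis=-2)                  # shape [..., n]

x = tf.constant([0, 4, 1, 2, 8, 7, 0, 2])
print(top_k_mask(x, 3).numpy())  # [0 1 0 0 1 1 0 0]

batch = tf.constant([[0, 4, 1, 2, 8, 7, 0, 2],
                     [5, 0, 3, 9, 1, 1, 6, 2]])
mask = top_k_mask(batch, 3)
```

Using reduce_max rather than reduce_sum keeps the result strictly 0/1 even if you later adapt the helper to inputs where indices could repeat.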
Please note that before v1.13, the top_k operation is under tf.nn.top_k:
_, indices = tf.nn.top_k(x, k=3)