sebjwallace

Reputation: 795

tensorflow - top k values in vector to binary vector

Say I have a vector with values [0,4,1,2,8,7,0,2]. How can I get a binary vector that marks the top k values in TensorFlow? For k = 3, the result should be [0,1,0,0,1,1,0,0].

Upvotes: 2

Views: 365

Answers (1)

Dmytro Prylipko

Reputation: 5064

TensorFlow's tf.math.top_k returns the k largest values together with their indices. To turn those indices into a binary mask, you can use tf.scatter_nd.

This code should do the job:

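import tensorflow as tf

x = tf.convert_to_tensor([0, 4, 1, 2, 8, 7, 0, 2])
_, indices = tf.math.top_k(x, k=3)  # indices of the 3 largest values
# Write a 1 at each top-k index; every other position stays 0
result = tf.scatter_nd(tf.expand_dims(indices, 1), tf.ones_like(indices), tf.shape(x))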

Output:

<tf.Tensor: id=47, shape=(8,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0], dtype=int32)>
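The scatter_nd call above builds the mask for a single 1-D vector. For a batch of vectors (a 2-D tensor), tf.math.top_k already works along the last axis, but its indices would need an extra batch coordinate before scatter_nd could use them. A minimal sketch, assuming TensorFlow 2.x in eager mode, swaps scatter_nd for tf.one_hot plus tf.reduce_sum; the helper name top_k_mask is hypothetical:

```python
import tensorflow as tf


def top_k_mask(x, k):
    """Hypothetical helper: 0/1 mask marking the k largest entries along the last axis."""
    _, indices = tf.math.top_k(x, k=k)                          # shape [..., k]
    depth = tf.shape(x)[-1]                                     # length of the last axis
    one_hot = tf.one_hot(indices, depth=depth, dtype=x.dtype)   # shape [..., k, depth]
    return tf.reduce_sum(one_hot, axis=-2)                      # collapse the k axis -> [..., depth]


batch = tf.constant([[0, 4, 1, 2, 8, 7, 0, 2],
                     [5, 1, 9, 3, 0, 2, 8, 6]])
print(top_k_mask(batch, 3).numpy().tolist())
# -> [[0, 1, 0, 0, 1, 1, 0, 0], [0, 0, 1, 0, 0, 0, 1, 1]]
```

Because top_k returns k distinct indices per row, summing the one-hot rows never produces a value above 1, so the result stays binary. The same helper also works on the original 1-D vector.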

Please note that before v1.13, the top_k operation lives under tf.nn.top_k instead:

_, indices = tf.nn.top_k(x, k=3)
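If ties at the k-th value are acceptable, a different technique from the scatter_nd approach above is to compare every entry against the k-th largest value. This is a sketch assuming TensorFlow 2.x; the helper name top_k_mask_threshold is hypothetical:

```python
import tensorflow as tf


def top_k_mask_threshold(x, k):
    """Hypothetical helper: marks every entry >= the k-th largest value along the last axis.

    With ties at the k-th value, this can mark MORE than k positions.
    """
    values, _ = tf.math.top_k(x, k=k)      # sorted in descending order
    kth_largest = values[..., -1:]         # keep the last axis so it broadcasts against x
    return tf.cast(x >= kth_largest, x.dtype)


x = tf.constant([0, 4, 1, 2, 8, 7, 0, 2])
print(top_k_mask_threshold(x, 3).numpy().tolist())  # -> [0, 1, 0, 0, 1, 1, 0, 0]
print(top_k_mask_threshold(x, 4).numpy().tolist())  # -> [0, 1, 0, 1, 1, 1, 0, 1]
```

The second call shows the caveat: with k = 4 the k-th largest value is 2, which appears twice, so five positions are marked. The scatter_nd version always marks exactly k.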

Upvotes: 1
