Reputation: 7632
Is there a way in TensorFlow to get the index of a value in a tensor?
For example, I have a one-hot matrix and I want to get the coordinates of the 1:
| 0 0 0 |
| 0 1 0 | => (1,1)
| 0 0 0 |
Preferably this would be done with a function, e.g. tensor.index(binary_function).
Upvotes: 0
Views: 303
Reputation: 2013
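You can use tf.where. Called with only a boolean condition, it returns the coordinates of every element where the condition is true.
For example: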
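import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

x = tf.constant([[0, 0, 0],
                 [0, 1, 0],
                 [0, 3, 0]])

with tf.Session() as sess:
    # Coordinates (row, column) of every element greater than 0
    coordinates = tf.where(tf.greater(x, 0))
    print(coordinates.eval())  # [[1 1], [2 1]]
    # Look up the values stored at those coordinates
    print(tf.gather_nd(x, coordinates).eval())  # [1 3]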
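If you are on TensorFlow 2.x, tf.Session is no longer available at the top level, so here is a rough sketch of the same idea with eager execution. It assumes TensorFlow 2.x is installed. The helper name index_where is made up for illustration (it is not part of TensorFlow); it just wraps tf.where to get something close to the tensor.index(binary_function) call from the question.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled by default


def index_where(tensor, predicate):
    """Hypothetical helper (not a TensorFlow API).

    Returns the coordinates of every element for which
    predicate(tensor) is True, one row per matching element.
    """
    return tf.where(predicate(tensor))


x = tf.constant([[0, 0, 0],
                 [0, 1, 0],
                 [0, 3, 0]])

# Coordinates (row, column) of every element greater than 0
coordinates = index_where(x, lambda t: t > 0)

# Values stored at those coordinates
values = tf.gather_nd(x, coordinates)

print(coordinates.numpy())
print(values.numpy())
```

For a true one-hot matrix the result has a single row, so coordinates[0] gives the (row, column) pair directly.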
Upvotes: 2