Dotan

Reputation: 7632

index() function in TensorFlow?

Is there a way in TensorFlow to get the index of a value in a tensor?

For example, I have a one-hot matrix and want to get the coordinates of the 1:

| 0 0 0 |
| 0 1 0 |  => (1,1)
| 0 0 0 |

Preferably this would be done with a function, e.g. something like tensor.index(binary_function).

Upvotes: 0

Views: 303

Answers (1)

Meuu

Reputation: 2013

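You can use tf.where. When it is called with only a condition, it returns the indices of all True elements as an int64 tensor of shape [number_of_matches, rank], one row of coordinates per match.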

For example,

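import tensorflow as tf  # TensorFlow 1.x API: tf.Session is not in the top-level namespace of TF 2.x

x = tf.constant([[0, 0, 0],
                 [0, 1, 0],
                 [0, 3, 0]])

with tf.Session() as sess:
    # tf.where with a single boolean argument returns the [row, col] of every True element
    coordinates = tf.where(tf.greater(x, 0))
    print(coordinates.eval()) # [[1 1], [2 1]]
    # tf.gather_nd looks the values back up at those coordinates
    print(tf.gather_nd(x, coordinates).eval()) # [1, 3]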
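On TensorFlow 2.x, where eager execution is the default and tf.Session is no longer available at the top level, the same approach can be sketched without a session. The sketch below also wraps tf.where in a function closer to the tensor.index(binary_function) the question asks for; index_where is a hypothetical helper name, not a TensorFlow API, and the snippet assumes TensorFlow 2.x is installed.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)


def index_where(tensor, predicate):
    """Hypothetical helper (not part of TensorFlow): return the coordinates of
    every element for which predicate(tensor) is True, as an int64 tensor
    of shape [number_of_matches, rank]."""
    return tf.where(predicate(tensor))


x = tf.constant([[0, 0, 0],
                 [0, 1, 0],
                 [0, 3, 0]])

# Coordinates of all positive entries, in row-major order.
coords = index_where(x, lambda t: tf.greater(t, 0))
print(coords.numpy())                   # [[1 1]
                                        #  [2 1]]

# Look the matching values back up at those coordinates.
print(tf.gather_nd(x, coords).numpy())  # [1 3]

# The one-hot case from the question: find where the single 1 is.
one_hot = tf.constant([[0, 0, 0],
                       [0, 1, 0],
                       [0, 0, 0]])
print(index_where(one_hot, lambda t: tf.equal(t, 1)).numpy())  # [[1 1]]
```

The predicate only has to return a boolean tensor with the same shape as the input, so any element-wise comparison (tf.equal, tf.greater, or several combined with tf.logical_and) can be passed in.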

Upvotes: 2
