Munichong

TensorFlow: get the indices of values in a tensor

Given a matrix and a vector, I want to find, for each row of the matrix, the index at which the corresponding element of the vector appears in that row.

m = tf.constant([[0, 2, 1],[2, 0, 1]])  # matrix
y = tf.constant([1,2])  # values whose indices should be found

The ideal output is [2, 0]: the first value of y, 1, is at index 2 of the first row of m, and the second value of y, 2, is at index 0 of the second row of m.
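To pin down the intended semantics, here is a plain-Python reference sketch (no TensorFlow involved). The function name `row_indices` is hypothetical, and it assumes each `y[i]` actually occurs in row `i` of `m`.

```python
def row_indices(m, y):
    """Hypothetical helper: for each row i, return the column where m[i] == y[i].

    Assumes y[i] occurs in m[i]. list.index raises ValueError if it does not,
    and returns the first match if the value occurs more than once.
    """
    return [row.index(value) for row, value in zip(m, y)]


m = [[0, 2, 1], [2, 0, 1]]
y = [1, 2]
print(row_indices(m, y))  # [2, 0]
```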

Answers (1)

Munichong

I just found one solution, but I don't know whether there are better ones.

import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

m = tf.constant([[0, 2, 1], [2, 0, 1]])  # matrix
y = tf.constant([1, 2])  # values whose indices should be found
y = tf.reshape(y, (-1, 1))  # [[1], [2]], so row i of m is compared with y[i]
cols = tf.where(tf.equal(m, y))[:, -1]  # [2, 0]: column index of each match

# No variables are defined, so no initializer is needed.
with tf.Session() as sess:
    print(sess.run(cols))

The above prints [2 0], i.e. the expected indices [2, 0].
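For TensorFlow 2.x, where `tf.Session` no longer exists, a minimal eager-mode sketch of the same idea might look like the following, assuming each value of `y` appears exactly once in its row. Option 1 is the answer's `tf.equal`/`tf.where` trick; if a row has no match or several matches, `tf.where` returns fewer or more entries and the result no longer lines up with the rows. Option 2 swaps in `tf.argmax` over the match mask, which always yields one index per row but gives a meaningless index for a row with no match.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

m = tf.constant([[0, 2, 1], [2, 0, 1]])  # matrix
y = tf.constant([1, 2])  # values whose indices should be found

# Broadcast y as a column so row i of m is compared against y[i].
mask = tf.equal(m, tf.expand_dims(y, 1))  # bool, shape [2, 3]

# Option 1: tf.where returns the [row, col] coordinates of every True entry
# in row-major order; the last column holds the column index (int64).
cols_where = tf.where(mask)[:, -1]

# Option 2: argmax over the int-cast mask gives exactly one column per row.
cols_argmax = tf.argmax(tf.cast(mask, tf.int32), axis=1)

print(cols_where.numpy())   # [2 0]
print(cols_argmax.numpy())  # [2 0]
```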
