beepretty

Reputation: 1127

How to gather elements by index along axis 1 in TensorFlow

For example,

import tensorflow as tf

index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])

If I use extract = tf.gather_nd(values, index), the result is

[[0.4 0.6]
 [0.4 0.6]]

However, I want the result to be

[[0.8], [0.6]]

where each index selects a column within its own row (i.e., along axis=1). However, tf.gather_nd has no axis parameter.
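To see why the first call returns whole rows, here is a simplified pure-Python model of tf.gather_nd with its default batch_dims=0. The function gather_nd_model is hypothetical, written only for illustration; it is not TensorFlow's implementation and performs no shape checks. It treats each innermost list in the indices as a coordinate prefix into params: a one-element vector like [1] addresses a whole row, while a two-element vector [row, column] addresses a single element.

```python
def gather_nd_model(params, indices):
    """Hypothetical, simplified model of tf.gather_nd (batch_dims=0).

    Each innermost list of ints in `indices` is a coordinate prefix into
    `params`, and it is replaced by the slice of `params` it addresses.
    """
    # An innermost index vector is a flat list of ints: walk into params.
    if all(isinstance(i, int) for i in indices):
        out = params
        for i in indices:
            out = out[i]
        return out
    # Otherwise recurse, keeping the outer structure of `indices`.
    return [gather_nd_model(params, sub) for sub in indices]


values = [[0.2, 0.8], [0.4, 0.6]]

# Each vector [1] addresses the whole row values[1].
print(gather_nd_model(values, [[1], [1]]))            # [[0.4, 0.6], [0.4, 0.6]]

# Vectors of the form [row, column] address single elements instead.
print(gather_nd_model(values, [[[0, 1]], [[1, 1]]]))  # [[0.8], [0.6]]
```

So to get one element per row, each index needs to carry its row number as well as its column.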

What should I do? Thanks!

Upvotes: 1

Views: 1492

Answers (1)

cs95

Reputation: 402493

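Stack a range of row numbers alongside index, so that each index becomes a full (row, column) coordinate, then call tf.gather_nd: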

index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)

result.eval(session=tf.Session())  # TensorFlow 1.x; in TF 2.x eager mode, use result.numpy()
array([[0.8],
       [0.6]], dtype=float32)
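A sketch of the same approach for TensorFlow 2.x, assuming eager execution (so no Session is needed). The tf.broadcast_to call is an addition not in the answer above: it repeats the row numbers to match index's shape, which lets the stacking also work if index has more than one column per row. The last step shows tf.gather with batch_dims=1 as an alternative that gathers along axis 1 within each row without building coordinates by hand.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x with eager execution

values = tf.constant([[0.2, 0.8], [0.4, 0.6]])
index = tf.constant([[1], [1]])  # column to pick in each row (int32)

# Row numbers 0..N-1 as a column vector, broadcast to index's shape so that
# rows[r, k] == r even if index has more than one column per row.
n_rows = tf.shape(index)[0]
rows = tf.broadcast_to(tf.range(n_rows)[:, None], tf.shape(index))

# Stack on a new last axis: coords[r, k] == [r, index[r, k]], shape (2, 1, 2).
coords = tf.stack([rows, index], axis=2)

# Each innermost pair addresses a single element, so result has shape (2, 1).
result = tf.gather_nd(values, coords)

# Alternative: with batch_dims=1, tf.gather treats axis 0 as a batch axis and
# gathers along axis 1 within each row, also producing shape (2, 1).
alt = tf.gather(values, index, batch_dims=1)

print(coords.numpy().tolist())  # [[[0, 1]], [[1, 1]]]
print(result.numpy())
print(alt.numpy())
```

Both result and alt should hold 0.8 and 0.6 in a (2, 1) tensor; the batch_dims route is usually shorter when the goal is simply "one gather per row".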

Upvotes: 2
