Reputation: 1127
For example,
import tensorflow as tf
index = tf.constant([[1],[1]])
values = tf.constant([[0.2, 0.8],[0.4, 0.6]])
If I use extract = tf.gather_nd(values, index),
the result is
[[0.4 0.6]
[0.4 0.6]]
However, I want the result to be
[[0.8], [0.6]]
That is, the index should be applied along axis 1, but tf.gather_nd has no axis parameter.
What should I do? Thanks!
Upvotes: 1
Views: 1492
Reputation: 402493
Concatenate a range to index:
index = tf.stack([tf.range(index.shape[0])[:, None], index], axis=2)
result = tf.gather_nd(values, index)
result.eval(session=tf.Session())
array([[0.8],
[0.6]], dtype=float32)
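To see why this works, here is a minimal pure-Python sketch of the indexing logic, using nested lists in place of tensors. The helpers add_row_numbers and gather_nd_2d are hypothetical names made up for illustration, not TensorFlow functions. They only mimic what tf.stack and tf.gather_nd do for this 2-D case: each column index is paired with its row number, so every lookup becomes a full (row, column) coordinate.

```python
# Pure-Python illustration of the indexing semantics (assumption: 2-D values,
# index of shape [rows, k]). These helpers are hypothetical, not TensorFlow APIs.

def add_row_numbers(index):
    """Pair each column index with its row number,
    mimicking tf.stack([tf.range(n)[:, None], index], axis=2)."""
    return [[[row, col] for col in cols] for row, cols in enumerate(index)]

def gather_nd_2d(values, pairs):
    """Look up values[row][col] for every innermost [row, col] pair."""
    return [[values[r][c] for r, c in row_pairs] for row_pairs in pairs]

values = [[0.2, 0.8], [0.4, 0.6]]
index = [[1], [1]]

pairs = add_row_numbers(index)        # [[[0, 1]], [[1, 1]]]
result = gather_nd_2d(values, pairs)
print(result)                         # [[0.8], [0.6]]
```

With the original index [[1], [1]], each entry selects a whole row (row 1 twice), which is why the question got [[0.4 0.6], [0.4 0.6]]. In TensorFlow versions where tf.gather accepts a batch_dims argument, tf.gather(values, index, batch_dims=1) should give the same result without building the stacked index.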
Upvotes: 2