Reputation: 95
Given two 2-D tensors
a = [[1 2 3 4][5 6 7 8]]
b = [[0 1][1 2]],
how can we get this:
c = [[1 2][6 7]]
That is, row i of b lists the columns to take from row i of a: columns 0 and 1 from the first row, and columns 1 and 2 from the second row.
Upvotes: 0
Views: 2225
Reputation: 59691
Here is a way to do that:
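import tensorflow as tf

a = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
b = tf.constant([[0, 1],
                 [1, 2]])
# Row index of each row of a: [0, 1]
row = tf.range(tf.shape(a)[0])
# Repeat each row index once per column of b: [[0, 0], [1, 1]]
row = tf.tile(row[:, tf.newaxis], (1, tf.shape(b)[1]))
# Pair row and column indices: [[[0, 0], [0, 1]], [[1, 1], [1, 2]]]
idx = tf.stack([row, b], axis=-1)
# Look up each (row, column) pair in a
c = tf.gather_nd(a, idx)
# TensorFlow 1.x graph mode: evaluate c in a session
with tf.Session() as sess:
    print(sess.run(c))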
Output:
[[1 2]
 [6 7]]
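To see what the index tensor holds, here is a rough plain-Python sketch of the same lookup. It is only an illustration and does not use TensorFlow: nested lists stand in for the tensors, and the comprehensions mimic what tf.stack and tf.gather_nd do in the code above.

```python
# Plain-Python stand-ins for the tensors a and b (illustration only).
a = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
b = [[0, 1],
     [1, 2]]

# Pair every column index in b with the index of the row it sits in.
# This plays the role of idx = tf.stack([row, b], axis=-1).
idx = [[(i, j) for j in cols] for i, cols in enumerate(b)]

# Look up each (row, column) pair in a, as tf.gather_nd does.
c = [[a[i][j] for (i, j) in pairs] for pairs in idx]

print(idx)  # [[(0, 0), (0, 1)], [(1, 1), (1, 2)]]
print(c)    # [[1, 2], [6, 7]]
```

If your TensorFlow version supports the batch_dims argument of tf.gather, then tf.gather(a, b, batch_dims=1) should give the same result without building idx by hand. I have not checked this against every release.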
Upvotes: 1