bemma

Reputation: 95

Extracting values from a tensor in tensorflow

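Given two matrix tensors

a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[0, 1], [1, 2]]

how can I get the following?

c = [[1, 2], [6, 7]]

That is, from the first row of a, extract the elements in columns 0 and 1, and from the second row, extract the elements in columns 1 and 2. The column indices for each row are given by the corresponding row of b.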
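For reference, the same row-wise selection can be written in NumPy. This is only an illustration of the indexing semantics, not part of the TensorFlow question: np.take_along_axis picks a[i, b[i, j]] for every row i, and plain fancy indexing with a broadcast row index does the same thing.

```python
import numpy as np

a = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])
b = np.array([[0, 1],
              [1, 2]])

# For every row i, select the columns listed in row i of b: c[i, j] = a[i, b[i, j]]
c = np.take_along_axis(a, b, axis=1)

# Equivalent fancy indexing: a column of row numbers broadcasts against b
rows = np.arange(a.shape[0])[:, np.newaxis]  # [[0], [1]]
c2 = a[rows, b]

print(c)
```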

Upvotes: 0

Views: 2225

Answers (1)

javidcf

Reputation: 59691

Here is a way to do that:

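import tensorflow as tf

a = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
b = tf.constant([[0, 1],
                 [1, 2]])
# Row index for every output element: [[0, 0], [1, 1]]
row = tf.range(tf.shape(a)[0])
row = tf.tile(row[:, tf.newaxis], (1, tf.shape(b)[1]))
# Pair each row index with its column index from b: shape (2, 2, 2)
idx = tf.stack([row, b], axis=-1)
# Look up a[row, col] for every (row, col) pair
c = tf.gather_nd(a, idx)
# TensorFlow 1.x graph mode; in TensorFlow 2.x (eager), use print(c.numpy()) instead
with tf.Session() as sess:
    print(sess.run(c))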

Output:

[[1 2]
 [6 7]]
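The code above relies on tf.Session, which exists only in TensorFlow 1.x (or as tf.compat.v1.Session). A minimal sketch for TensorFlow 2.x, assuming eager execution: tf.gather with batch_dims=1 treats each row as its own batch and gathers, for row i, the columns listed in row i of b, so it should produce the same result without building the index tensor by hand.

```python
import tensorflow as tf

a = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
b = tf.constant([[0, 1],
                 [1, 2]])

# batch_dims=1: gather separately within each row, so c[i, j] = a[i, b[i, j]]
c = tf.gather(a, b, batch_dims=1)

print(c.numpy())
```

The tf.gather_nd version is still useful when the row indices are not simply 0, 1, 2, ..., because it lets you supply arbitrary (row, column) pairs.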

Upvotes: 1
