NicoJ

Reputation: 98

TensorFlow tf.gather with axis parameter

I am using TensorFlow's tf.gather to get elements from a multidimensional array like this:

import tensorflow as tf

indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

result = tf.gather(x, indices, axis=1)

with tf.Session() as sess:
    selection = sess.run(result)
    print(selection)

This results in:

[[1 2 2]
 [4 5 5]
 [7 8 8]]

What I want, though, is:

[1
 5
 8]

How can I use tf.gather so that each index is applied to a single row along the specified axis? (I want the same result as the workaround in this answer: https://stackoverflow.com/a/41845855/9763766)

Upvotes: 6

Views: 2275

Answers (1)

Vijay Mariappan

Reputation: 17201

You need to convert the indices to full (row, column) indices and use tf.gather_nd. This can be done with:

result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))
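
To make the "full indices" idea more concrete, here is a minimal sketch of the same approach, split into steps. It assumes TensorFlow 2.x with eager execution (under 1.x, evaluate the result with sess.run as in the question). The helper name gather_per_row is hypothetical, not part of TensorFlow.

```python
import tensorflow as tf


def gather_per_row(x, indices):
    """Pick x[i, indices[i]] for every row i (hypothetical helper name)."""
    indices = tf.cast(indices, tf.int32)
    # Row numbers 0..n-1, one per entry of `indices`.
    rows = tf.range(tf.shape(indices)[0])
    # Pair each row number with its column index: shape (n, 2).
    full_indices = tf.stack([rows, indices], axis=1)
    # Each (row, column) pair selects one scalar, so the result has shape (n,).
    return tf.gather_nd(x, full_indices)


x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
indices = tf.constant([0, 1, 1])

result = gather_per_row(x, indices)  # values: [1, 5, 8]
print(result)
```

Stacking along axis=1 gives index pairs of shape (n, 2), so tf.gather_nd already returns a vector of shape (n,) and no tf.squeeze is needed. This also avoids squeezing away the batch dimension when x has only one row.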

Upvotes: 2
