Reputation: 1299
Let's assume I have a batch consisting of two tensors, and each tensor in the batch has size 3.
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
Now I want to extract a single element from each tensor in the batch, based on an index:
index = [0, 2]
The output should therefore be
out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.
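As a reference for the intended semantics, NumPy's integer-array indexing does exactly this: pairing each row number with its own column index picks one element per row. This is a minimal sketch in plain NumPy (not TensorFlow), handy for checking a TensorFlow result against.

```python
import numpy as np

data = np.array([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])
index = np.array([0, 2])

# Row i is paired with column index[i], so out[i] == data[i, index[i]].
out = data[np.arange(len(index)), index]
print(out)  # out == [0.3, -0.7]
```

The same pattern works for any batch size, as long as len(index) equals the number of rows.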
Of course, this should extend to large batch sizes. The length of index is equal to the batch size.
I tried tf.gather and tf.gather_nd, but I did not get the results I wanted.
For example, the code below prints 0.7 rather than the desired result specified above:
import tensorflow as tf
data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
index = [0, 2]
out = tf.gather_nd(data, index)
print(out.numpy())
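The tf.gather attempt can likely be made to work through its batch_dims argument: with batch_dims=1, the first dimension of data and of index is treated as a shared batch dimension, so each row is gathered with its own index. A minimal sketch, assuming TensorFlow 2.x with eager execution:

```python
import tensorflow as tf

data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])
index = tf.constant([0, 2])

# batch_dims=1: row i of data is gathered with index[i] along axis 1
# (the first non-batch axis), so out[i] == data[i, index[i]].
out = tf.gather(data, index, batch_dims=1)
print(out.numpy())
```

This avoids building the [row, column] pairs by hand.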
Upvotes: 2
Views: 1017
Reputation: 11333
If you know the batch size, you can do the following:
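import numpy as np
import tensorflow as tf
data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])
index = [0, 2]
# Pair each row number with its column index: [[0, 0], [1, 2]]
gather_inds = np.stack([np.arange(len(index)), index], axis=1)
out = tf.gather_nd(data, gather_inds)
print(out.numpy())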
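Your gather_nd call didn't work because the last dimension of the indices is treated as one full coordinate into data. A 1-D index [0, 2] therefore selects the single element data[0][2], which is 0.7. To get one element per tensor in the batch, each index has to be a coordinate whose length equals the rank of your data tensor (2 here), and you need one such coordinate per batch element. In other words, your indices should be
[[0, 0], [1, 2]]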
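The np.arange(len(index)) above needs the batch size at Python time. If the batch size is only known at run time (for example inside a tf.function with an unspecified batch dimension), the same [row, column] pairs can be built with TensorFlow ops instead. A minimal sketch, assuming TensorFlow 2.x; the function name gather_per_row is hypothetical:

```python
import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.float32),  # data: [batch, n]
    tf.TensorSpec(shape=[None], dtype=tf.int32),           # index: [batch]
])
def gather_per_row(data, index):
    # Row numbers 0..batch-1, taken from the runtime shape of index.
    rows = tf.range(tf.shape(index)[0])
    # Pair each row number with its column index: shape [batch, 2],
    # e.g. [[0, 0], [1, 2]] for index = [0, 2].
    gather_inds = tf.stack([rows, index], axis=1)
    # Each [row, col] pair is a full coordinate into the rank-2 data tensor.
    return tf.gather_nd(data, gather_inds)

data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])
index = tf.constant([0, 2])
print(gather_per_row(data, index).numpy())
```

Because the batch dimension is declared as None, the same traced function handles any batch size without retracing.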
Upvotes: 2