Andreas Pasternak

Extract one element per batch from TensorFlow 2.1 tensor

Let's assume I have a batch consisting of two tensors, each of size 3:

data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]

Now I want to extract a single element from each tensor in the batch, based on an index:

index = [0, 2]

The output should therefore be:

out = [0.3, -0.7] # Get index 0 from the first tensor in the batch and index 2 from the second tensor in the batch.

Of course, this should extend to large batch sizes. The length of index equals the batch size.

I tried tf.gather and tf.gather_nd, but neither gave the result I wanted.

For example, the code below prints 0.7 instead of the desired result above:

import tensorflow as tf

data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]

index = [0, 2]
out = tf.gather_nd(data, index)

print(out.numpy())

Answers (1)

thushv89

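If the indices are available as a Python list or NumPy array (so the batch size is known on the Python side), you can do the following:

import numpy as np
import tensorflow as tf

data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])

index = [0, 2]
# Pair each batch position with its index: [[0, 0], [1, 2]]
gather_inds = np.stack([np.arange(len(index)), index], axis=1)
out = tf.gather_nd(data, gather_inds)
print(out.numpy())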
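If the indices only exist as a tensor (for example inside a tf.function), the NumPy step can be replaced with TensorFlow ops so the batch size is read at runtime. This is a minimal sketch, not part of the original answer: the helper names gather_one_per_row and gather_one_per_row_batch_dims are hypothetical, and the second helper assumes the batch_dims argument of tf.gather available in TF 2.x.

```python
import tensorflow as tf


def gather_one_per_row(data, index):
    """Return data[i, index[i]] for each row i of a 2-D tensor (hypothetical helper)."""
    data = tf.convert_to_tensor(data)
    index = tf.convert_to_tensor(index, dtype=tf.int32)
    # Batch positions 0..B-1 come from the runtime shape of `index`,
    # so the batch size does not have to be known in advance.
    batch_pos = tf.range(tf.shape(index)[0], dtype=tf.int32)
    # Pair each batch position with its per-row index -> shape [B, 2].
    gather_inds = tf.stack([batch_pos, index], axis=1)
    return tf.gather_nd(data, gather_inds)


def gather_one_per_row_batch_dims(data, index):
    """Same result via tf.gather with batch_dims=1 (hypothetical helper)."""
    index = tf.convert_to_tensor(index, dtype=tf.int32)
    # With batch_dims=1, row i of `data` is indexed by index[i].
    return tf.gather(data, index, batch_dims=1)


data = tf.constant([[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]])
print(gather_one_per_row(data, [0, 2]).numpy())
print(gather_one_per_row_batch_dims(data, [0, 2]).numpy())
```

The first helper mirrors the answer's approach step for step; the batch_dims version skips building the coordinate pairs explicitly.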

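Your tf.gather_nd call didn't work because a 1-D index whose length equals the rank of data is treated as one full coordinate: [0, 2] means data[0][2], which is 0.7. To get one element per batch entry, you need one coordinate per row, each with as many components as the rank of data (the batch position, then the position within that row). In other words, your indices should be

[0, 0] and [1, 2], stacked into the index tensor [[0, 0], [1, 2]], which is what gather_inds holds.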
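To make the difference concrete, here is a simplified NumPy model of how tf.gather_nd reads its indices argument when batch_dims is 0: the last axis of indices holds coordinates into params. It is an illustrative sketch only; the name gather_nd_sketch is hypothetical, and it ignores batch_dims and error checking, so it is not TensorFlow's implementation.

```python
import numpy as np


def gather_nd_sketch(params, indices):
    """Simplified model of tf.gather_nd (batch_dims=0, hypothetical helper)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    depth = indices.shape[-1]  # number of coordinate components per index
    # Flatten all leading index axes, look up each coordinate tuple ...
    flat = indices.reshape(-1, depth)
    picked = [params[tuple(coord)] for coord in flat]
    # ... then restore the leading shape plus any untouched trailing axes.
    return np.array(picked).reshape(indices.shape[:-1] + params.shape[depth:])


data = [[0.3, 0.5, 0.7], [-0.3, -0.5, -0.7]]
print(gather_nd_sketch(data, [0, 2]))            # one coordinate -> a single scalar
print(gather_nd_sketch(data, [[0, 0], [1, 2]]))  # two coordinates -> one value per row
```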
