I have a tensor called my_tensor with the shape [batch_size, seq_length]
and another tensor named idx with the shape [batch_size, 1],
which contains indices that start at 0 and end at seq_length - 1.
I want to extract one value from each row of my_tensor, using the indices defined in idx.
I tried to use tf.gather_nd and tf.gather, but I was not successful.
Consider the following example:
import tensorflow as tf
batch_size = 3
seq_length = 5
idx = [2, 0, 4]
my_tensor = tf.random.uniform(shape=(batch_size, seq_length))
I want to get the values at the (row, column) positions [[0, 2], [1, 0], [2, 4]] of my_tensor.
I need to do further processing on them, so I would like to extract them all at once (if that is even possible) and efficiently; however, I could not come up with a way to do it.
I appreciate any help :)
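The (row, column) pairs listed above can be built from idx with tf.range and tf.stack and then passed to tf.gather_nd, the op the question mentions trying. This is a minimal sketch, assuming TensorFlow 2.x with eager execution; a small fixed my_tensor stands in for the random one so the output is predictable.

```python
import tensorflow as tf

batch_size = 3
seq_length = 5

# Fixed values instead of tf.random.uniform so the result is reproducible:
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
my_tensor = tf.reshape(tf.range(batch_size * seq_length, dtype=tf.float32),
                       (batch_size, seq_length))

# idx with shape [batch_size, 1], as described in the question.
idx = tf.constant([[2], [0], [4]])

# Pair each row number with that row's column index: [[0, 2], [1, 0], [2, 4]].
rows = tf.range(batch_size)
pairs = tf.stack([rows, tf.squeeze(idx, axis=1)], axis=1)

# gather_nd picks my_tensor[row, col] for every (row, col) pair.
values = tf.gather_nd(my_tensor, pairs)
print(values.numpy())  # [ 2.  5. 14.]
```

All values come back in a single op with shape [batch_size], so they can be processed together.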
The trick is to first convert your set of indices into a boolean mask, which you can then use to reduce my_tensor as you have described with the tf.boolean_mask operation.
You can build this mask by one-hot encoding the idx tensor.
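So, where idx = [2, 0, 4], we can use tf.one_hot(idx, seq_length) to convert it to the following (if idx has shape [batch_size, 1], as in the question, flatten it first with tf.squeeze(idx, axis=1); otherwise the one-hot result has shape [batch_size, 1, seq_length]):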
[ [0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1.] ]
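The mask construction step can be sketched as follows, assuming TensorFlow 2.x and idx in the [batch_size, 1] shape from the question. The explicit cast to tf.bool is an added precaution: tf.one_hot returns float32 by default, while tf.boolean_mask documents a boolean mask.

```python
import tensorflow as tf

seq_length = 5
# idx as in the question, with shape [batch_size, 1].
idx = tf.constant([[2], [0], [4]])

# tf.one_hot adds a new last axis, so drop the trailing size-1 axis first;
# otherwise the mask would have shape [batch_size, 1, seq_length].
flat_idx = tf.squeeze(idx, axis=1)

# One-hot rows are float32 (1.0 / 0.0); convert them to True / False.
mask = tf.cast(tf.one_hot(flat_idx, seq_length), tf.bool)
print(mask.numpy().astype(int))
# [[0 0 1 0 0]
#  [1 0 0 0 0]
#  [0 0 0 0 1]]
```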
Then, putting it all together for, say, the following my_tensor:
[ [0.6413697 , 0.4079175 , 0.42499018, 0.3037368 , 0.8580252 ],
[0.8698617 , 0.29096508, 0.11531639, 0.25421357, 0.5844104 ],
[0.6442119 , 0.31816053, 0.6245482 , 0.7249261 , 0.7595779 ] ]
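we can proceed as follows (the cast is there because tf.boolean_mask expects a boolean mask, while tf.one_hot returns floats by default):
result = tf.boolean_mask(my_tensor, tf.cast(tf.one_hot(idx, seq_length), tf.bool))
to give:
[0.42499018, 0.8698617 , 0.7595779 ]
as expected.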
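Here is a self-contained version of the steps above using the same example values, assuming TensorFlow 2.x with eager execution. For comparison it also shows tf.gather with batch_dims=1, a swapped-in alternative (not part of the answer above) that reads one column per row directly without building a mask.

```python
import tensorflow as tf

seq_length = 5
my_tensor = tf.constant([
    [0.6413697, 0.4079175, 0.42499018, 0.3037368, 0.8580252],
    [0.8698617, 0.29096508, 0.11531639, 0.25421357, 0.5844104],
    [0.6442119, 0.31816053, 0.6245482, 0.7249261, 0.7595779],
])
idx = tf.constant([[2], [0], [4]])  # shape [batch_size, 1]
flat_idx = tf.squeeze(idx, axis=1)  # shape [batch_size]

# Approach from the answer: one-hot mask, then boolean_mask.
mask = tf.cast(tf.one_hot(flat_idx, seq_length), tf.bool)
via_mask = tf.boolean_mask(my_tensor, mask)

# Alternative: gather along axis 1, treating axis 0 as a batch dimension,
# so row i contributes my_tensor[i, flat_idx[i]].
via_gather = tf.gather(my_tensor, flat_idx, axis=1, batch_dims=1)

print(via_mask.numpy())
print(via_gather.numpy())
```

One behavioral difference to keep in mind: an out-of-range index produces an all-zero one-hot row, so tf.boolean_mask silently returns fewer than batch_size values, whereas tf.gather raises an error for out-of-range indices on CPU.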