Let's say I have a tensor called my_tensor of size [batch_size, 5, 10].
I also have another tensor called selecter of size [batch_size, 1] holding indices.
I want to filter my_tensor with respect to selecter to produce a new tensor of size [batch_size, 10], i.e. for each batch element b, keep only the row my_tensor[b, selecter[b, 0], :]. Basically, it's kind of reducing the middle dimension (which has size 5).
I feel like tf.where is the right choice, but I'm not sure about it.
I would really appreciate your help!
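To pin down the intended result, here is a minimal pure-Python reference, assuming my_tensor is given as nested lists and selecter as a list of one-element lists (the [batch_size, 1] shape). The function name select_rows is hypothetical and is not part of TensorFlow; the example uses a middle dimension of 3 instead of 5 to keep it short.

```python
from typing import List


def select_rows(my_tensor: List[List[List[float]]],
                selecter: List[List[int]]) -> List[List[float]]:
    """Hypothetical reference: out[b] = my_tensor[b][selecter[b][0]]."""
    # For each batch element b, keep only the row of the middle dimension
    # that selecter[b][0] points to and drop the other rows.
    return [my_tensor[b][selecter[b][0]] for b in range(len(my_tensor))]


# Tiny example: batch_size=2, middle dimension 3, last dimension 2.
my_tensor = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
selecter = [[2], [0]]
print(select_rows(my_tensor, selecter))  # [[4, 5], [6, 7]]
```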
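Alternative solution, works in TensorFlow 1.3. It builds a one-hot boolean mask over the middle dimension and keeps only the selected rows:
import tensorflow as tf

idx = tf.squeeze(selecter, axis=1)  # [batch_size, 1] -> [batch_size]
depth = tf.shape(my_tensor)[1]  # size of the middle dimension (5)
# True only at column idx[b] in row b: shape [batch_size, depth]
mask = tf.logical_xor(
    tf.sequence_mask(idx + 1, depth),
    tf.sequence_mask(idx, depth))
result = tf.boolean_mask(my_tensor, mask)  # shape [batch_size, 10]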
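The mask trick relies on a simple fact: for an index i, sequence_mask(i + 1, n) is True at positions 0..i and sequence_mask(i, n) is True at positions 0..i-1, so their XOR is True only at position i. A plain-Python sketch of that reasoning follows; sequence_mask_row and one_hot_mask are hypothetical helpers that model a single row of tf.sequence_mask, not the TensorFlow function itself.

```python
from typing import List


def sequence_mask_row(length: int, maxlen: int) -> List[bool]:
    """Hypothetical stand-in for one row of tf.sequence_mask: True where j < length."""
    return [j < length for j in range(maxlen)]


def one_hot_mask(index: int, maxlen: int) -> List[bool]:
    # XOR of the two prefix masks leaves True only at position `index`.
    upper = sequence_mask_row(index + 1, maxlen)
    lower = sequence_mask_row(index, maxlen)
    return [a != b for a, b in zip(upper, lower)]


print(one_hot_mask(2, 5))  # [False, False, True, False, False]
```

Since tf.boolean_mask expects the mask's shape to match the leading dimensions of my_tensor, maxlen should be the size of the middle dimension (5 here), not just the largest index that appears in selecter.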
The solution is to go with tf.gather_nd:
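import tensorflow as tf
batch_idx = tf.range(tf.shape(my_tensor)[0])  # [0, 1, ..., batch_size - 1], int32
# (b, selecter[b]) index pairs, shape [batch_size, 2]; selecter must also be int32
indices = tf.stack([batch_idx, tf.squeeze(selecter, axis=1)], axis=-1)
result = tf.gather_nd(my_tensor, indices)  # shape [batch_size, 10]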
You can get rid of the squeeze if you construct selecter to be 1-D from the beginning.
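As a rough pure-Python model of what this answer computes: tf.stack builds one (batch index, selected row) pair per batch element, and tf.gather_nd looks each pair up in my_tensor. The helpers build_indices and gather_nd_2 are hypothetical and only mimic this two-component-index case; they are not TensorFlow's implementation.

```python
from typing import List, Sequence, Tuple


def build_indices(selecter: Sequence[int]) -> List[Tuple[int, int]]:
    # Mirrors tf.stack([tf.range(batch_size), selecter], axis=-1) for a 1-D selecter.
    return [(b, s) for b, s in enumerate(selecter)]


def gather_nd_2(params, indices: List[Tuple[int, int]]):
    # Each (b, i) pair selects params[b][i]; the trailing dimension is kept whole.
    return [params[b][i] for b, i in indices]


my_tensor = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
selecter = [1, 2]  # already 1-D, so no squeeze is needed
indices = build_indices(selecter)
print(indices)                          # [(0, 1), (1, 2)]
print(gather_nd_2(my_tensor, indices))  # [[2, 3], [10, 11]]
```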