Oleg Yarin

Reputation: 161

How to filter a TensorFlow tensor based on a tensor of indices?

Let's say I have a tensor called my_tensor of shape [batch_size, 5, 10]. I also have another tensor, selecter, of shape [batch_size, 1] that holds indices.

I want to filter my_tensor with respect to selecter to produce a new tensor of shape [batch_size, 10]. For each batch element b, I want to keep only the row my_tensor[b, selecter[b], :]. Basically, this reduces the middle dimension (which has size 5).

I feel like tf.where is the right choice, but I'm not sure. I would really appreciate your help!
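For reference, the requested result can be stated precisely with NumPy integer-array indexing. This is only a sketch of the target semantics, not TensorFlow code; the batch size, data and index values below are made up for illustration.

```python
import numpy as np

# Hypothetical data with the shapes from the question
batch_size = 3
rng = np.random.default_rng(0)
my_tensor = rng.standard_normal((batch_size, 5, 10))
selecter = np.array([[4], [0], [2]])  # shape [batch_size, 1]

# For each batch element b, take row selecter[b] of my_tensor[b]
result = my_tensor[np.arange(batch_size), selecter[:, 0]]
print(result.shape)  # (3, 10)
```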

Upvotes: 0

Views: 1986

Answers (2)

omikron

Reputation: 2825

Alternative solution, which works in TensorFlow 1.3: build a one-hot boolean mask over the middle dimension and apply it with tf.boolean_mask.

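import tensorflow as tf

# selecter is [batch_size, 1]; flatten it to [batch_size]
indices = tf.squeeze(selecter, axis=1)
# The mask must span the whole middle dimension (size 5 here)
depth = tf.shape(my_tensor)[1]
# XOR of the two sequence masks is True only at position indices[b] in row b
mask = tf.logical_xor(
    tf.sequence_mask(indices + 1, depth),
    tf.sequence_mask(indices, depth))
# Exactly one True per row, so the result has shape [batch_size, 10]
filtered = tf.boolean_mask(my_tensor, mask)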
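To see why the XOR of two sequence masks yields a one-hot mask, here is a NumPy analogue. The sequence_mask helper below is a hypothetical stand-in that mimics tf.sequence_mask for 1-D lengths, and NumPy boolean indexing plays the role of tf.boolean_mask; the data is made up.

```python
import numpy as np

def sequence_mask(lengths, maxlen):
    # Hypothetical stand-in for tf.sequence_mask: row i is True at positions < lengths[i]
    return np.arange(maxlen) < np.asarray(lengths)[:, None]

my_tensor = np.arange(3 * 5 * 10).reshape(3, 5, 10)
indices = np.array([4, 0, 2])  # selecter already flattened to 1-D
depth = my_tensor.shape[1]

# "positions < i + 1" XOR "positions < i" leaves exactly position i set
mask = np.logical_xor(sequence_mask(indices + 1, depth),
                      sequence_mask(indices, depth))

# Boolean indexing over the first two axes keeps one row per batch element
result = my_tensor[mask]
print(result.shape)  # (3, 10)
```

Because each row of the mask has exactly one True entry, the masked result keeps one 10-element row per batch element, in batch order.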

Upvotes: 0

P-Gn

Reputation: 24581

The solution is to go with tf.gather_nd.

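import tensorflow as tf
# One (batch index, selected row) pair per example -> shape [batch_size, 2]
indices = tf.stack([tf.range(batch_size), tf.squeeze(selecter, axis=1)], axis=-1)
# Picks my_tensor[b, selecter[b], :] for every b -> shape [batch_size, 10]
result = tf.gather_nd(my_tensor, indices)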

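You can get rid of the squeeze if you construct selecter to be 1-D from the beginning. Note that tf.stack requires both inputs to have the same dtype, so selecter should be int32 to match tf.range(batch_size).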
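The index tensor built by tf.stack holds one (batch, row) pair per example. The NumPy sketch below, with made-up data, shows those pairs and mirrors the lookup tf.gather_nd performs when the last index dimension has size 2; it is an analogue of the semantics, not a call into TensorFlow.

```python
import numpy as np

my_tensor = np.arange(3 * 5 * 10).reshape(3, 5, 10)
selecter = np.array([[4], [0], [2]])  # shape [batch_size, 1]
batch_size = my_tensor.shape[0]

# Same construction as the tf.stack call: one (batch, row) pair per example
indices = np.stack([np.arange(batch_size), selecter.squeeze(axis=1)], axis=-1)
print(indices.tolist())  # [[0, 4], [1, 0], [2, 2]]

# gather_nd with pairs returns params[b, r, :] for each (b, r);
# NumPy expresses the same lookup by splitting the pairs into two index arrays
result = my_tensor[indices[:, 0], indices[:, 1]]
print(result.shape)  # (3, 10)
```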

Upvotes: 2
