PinkBanter

Reputation: 1976

How to use masked_select for these kinds of tensors?

Suppose I have a tensor a and a tensor b:

import torch

a = torch.tensor([[[ 0.8856,  0.1411, -0.1856, -0.1425],
                   [-0.0971,  0.1251,  0.1608, -0.1302],
                   [-0.0901,  0.3215,  0.1763, -0.0412]],
                  [[ 0.8856,  0.1411, -0.1856, -0.1425],
                   [-0.0971,  0.1251,  0.1608, -0.1302],
                   [-0.0901,  0.3215,  0.1763, -0.0412]]])  # shape (2, 3, 4)
b = torch.tensor([[0, 2, 1],
                  [0, 2, 1]])                               # shape (2, 3)

Now I would like to select the entries of tensor a where the corresponding value of tensor b is not 0.

pred_masks = ( b != 0 )
c = torch.masked_select( a, (pred_masks == 1))

As expected, this raises an error:

----> 1 c = torch.masked_select( a, (pred_masks == 1))

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

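The error occurs because each innermost list of a has 4 items: masked_select broadcasts the mask against a, and a mask of shape (2, 3) cannot be broadcast to a's shape (2, 3, 4). What I need is to select all the values of the innermost list at position x in tensor a whenever the entry at position x in tensor b is nonzero.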
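The broadcasting rule behind this error can be checked directly with torch.broadcast_shapes. This is a minimal sketch, assuming PyTorch 1.8 or later (where that function exists): shapes are aligned from the last dimension, so a (2, 3) mask lines its 3 up against a's 4 and fails, while a (2, 3, 1) shape broadcasts cleanly to (2, 3, 4).

```python
import torch

# Shapes are compared from the trailing dimension backwards.
# (2, 3, 1) vs (2, 3, 4): 1 stretches to 4, so this works.
ok_shape = torch.broadcast_shapes((2, 3, 4), (2, 3, 1))
print(ok_shape)  # torch.Size([2, 3, 4])

# (2, 3) vs (2, 3, 4): the trailing 3 is compared with 4, which fails,
# the same mismatch masked_select reports in the traceback above.
mismatch_raised = False
try:
    torch.broadcast_shapes((2, 3, 4), (2, 3))
except RuntimeError as err:
    mismatch_raised = True
    print("not broadcastable:", err)
```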

I will be grateful for any hint or answer.

Upvotes: 2

Views: 4965

Answers (1)

Niklas Höpner

Reputation: 424

I am not quite sure what shape you want for the output c. Since your mask has shape (2, 3) and a has shape (2, 3, 4), do you want an output tensor of shape (n, 4), where n is the number of True elements in the (2, 3) mask?
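For comparison, the other obvious choice of output shape is to keep a's full (2, 3, 4) shape and zero out the rows where b is 0. A minimal sketch of that alternative, assuming zero is an acceptable fill value (the names block and kept are illustrative):

```python
import torch

block = [[ 0.8856,  0.1411, -0.1856, -0.1425],
         [-0.0971,  0.1251,  0.1608, -0.1302],
         [-0.0901,  0.3215,  0.1763, -0.0412]]
a = torch.tensor([block, block])          # shape (2, 3, 4), same values as the question
b = torch.tensor([[0, 2, 1], [0, 2, 1]])  # shape (2, 3)
pred_masks = b != 0                       # bool mask, shape (2, 3)

# ~pred_masks is True exactly where b == 0; unsqueeze(-1) turns it into
# a (2, 3, 1) mask that masked_fill broadcasts over the last dimension.
kept = a.masked_fill(~pred_masks.unsqueeze(-1), 0.0)
print(kept.shape)  # torch.Size([2, 3, 4])
```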

If so, I would suggest simply using the mask as an index into the first two dimensions:

c = a[pred_masks,:]
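A minimal runnable sketch of this indexing on the question's data, which also shows that torch.masked_select gives the same rows once the mask gets a trailing singleton dimension (via unsqueeze(-1)) and the flat result is regrouped into rows of 4. The names block, flat and c_alt are illustrative:

```python
import torch

block = [[ 0.8856,  0.1411, -0.1856, -0.1425],
         [-0.0971,  0.1251,  0.1608, -0.1302],
         [-0.0901,  0.3215,  0.1763, -0.0412]]
a = torch.tensor([block, block])          # shape (2, 3, 4), same values as the question
b = torch.tensor([[0, 2, 1], [0, 2, 1]])  # shape (2, 3)
pred_masks = b != 0                       # bool mask, shape (2, 3)

# The (2, 3) boolean index covers a's first two dimensions; the last
# dimension (size 4) is kept whole, so the result has shape (n, 4).
c = a[pred_masks, :]
n = int(pred_masks.sum())                 # number of True entries in the mask
print(c.shape)                            # torch.Size([4, 4])

# Same result with masked_select: the (2, 3, 1) mask broadcasts to
# (2, 3, 4), the output is 1-D (n * 4 values), then regrouped into rows.
flat = torch.masked_select(a, pred_masks.unsqueeze(-1))
c_alt = flat.reshape(-1, a.size(-1))
print(torch.equal(c, c_alt))              # True
```

Both approaches return the selected rows in row-major order, here a[0, 1], a[0, 2], a[1, 1], a[1, 2].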

Hope that helps a bit.

Upvotes: 3
