Reputation: 1976
Suppose I have a tensor a and a tensor b:
import torch
a = torch.tensor([[[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]],
                  [[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]]])
b = torch.tensor([[0, 2, 1],
                  [0, 2, 1]])
Now I would like to select the entries of tensor a at the positions where the value of tensor b is not 0.
pred_masks = ( b != 0 )
c = torch.masked_select( a, (pred_masks == 1))
As expected, this raises an error:
----> 1 c = torch.masked_select( a, (pred_masks == 1))
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2
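This happens because the innermost lists of a contain 4 items: masked_select broadcasts the (2, 3) mask against the (2, 3, 4) tensor a by aligning trailing dimensions, so the mask's size 3 is compared with a's last dimension of size 4. What I need, however, is to select all values of the innermost list at position x in tensor a whenever the entry at the same position x in tensor b is non-zero.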
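A minimal sketch of one way to make torch.masked_select work here, assuming the same a and b as above: give the mask a trailing dimension and expand it to a's full shape (2, 3, 4), so every value in a selected row is kept. Because masked_select always returns a flat 1-D tensor, the result is then reshaped into rows of length 4. The name full_mask is only illustrative.

```python
import torch

a = torch.tensor([[[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]],
                  [[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]]])
b = torch.tensor([[0, 2, 1],
                  [0, 2, 1]])

pred_masks = b != 0  # shape (2, 3), dtype torch.bool

# Add a trailing dimension and expand it so the mask matches a's shape (2, 3, 4).
full_mask = pred_masks.unsqueeze(-1).expand_as(a)

# masked_select returns a flat 1-D tensor of the selected values (here 16 of them) ...
flat = torch.masked_select(a, full_mask)

# ... so reshape it back into rows of length a.size(-1).
c = flat.reshape(-1, a.size(-1))
print(c.shape)  # torch.Size([4, 4])
```

The values come out in row-major order, so each row of c is one innermost list of a whose position in b is non-zero.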
I will be grateful for any hint or answer.
Upvotes: 2
Views: 4965
Reputation: 424
I am not sure what shape you want for the output c. Since your mask has shape (2, 3) and a has shape (2, 3, 4), do you want an output of shape (n, 4), where n is the number of True elements in the (2, 3) mask?
If so, I would suggest simply using the mask as an index into the first two dimensions:
c = a[pred_masks, :]
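A self-contained sketch of the indexing approach above, using the question's a and b: each True entry of the (2, 3) boolean mask selects one full row of length 4 from a. The second part, using torch.nonzero(..., as_tuple=True), is only an illustration that the mask picks the same (i, j) positions; the names rows, cols, and c_idx are hypothetical.

```python
import torch

a = torch.tensor([[[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]],
                  [[ 0.8856, 0.1411, -0.1856, -0.1425],
                   [-0.0971, 0.1251,  0.1608, -0.1302],
                   [-0.0901, 0.3215,  0.1763, -0.0412]]])
b = torch.tensor([[0, 2, 1],
                  [0, 2, 1]])
pred_masks = b != 0

# Boolean indexing over the first two dimensions gives shape (n, 4),
# where n = pred_masks.sum() is the number of non-zero entries in b.
c = a[pred_masks, :]  # equivalent to a[pred_masks]
print(c.shape)  # torch.Size([4, 4])

# The same rows, selected with explicit (i, j) index tensors.
rows, cols = torch.nonzero(pred_masks, as_tuple=True)
c_idx = a[rows, cols]
print(torch.equal(c, c_idx))  # True
```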
Hope that helps a bit.
Upvotes: 3