Reputation: 371
I am looking for an elegant way to select a subset of a torch tensor which satisfies some constraints. For example, say I have:
A = torch.rand(10,2)-1
and S is a 10x1 tensor,
sel = torch.ge(S,5) -- this is a ByteTensor
I would like to be able to do logical indexing, as follows:
A1 = A[sel]
But that doesn't work.
So there's the index function, which accepts a LongTensor, but I could not find a simple way to convert S to a LongTensor, except the following:
sel = torch.nonzero(sel)
which returns a K x 2 tensor (K being the number of values of S >= 5). So then I have to convert it to a one-dimensional tensor, which finally allows me to index A:
A:index(1,torch.squeeze(sel:select(2,1)))
This is very cumbersome; in Matlab, for example, all I'd have to do is
A(S>=5,:)
Can anyone suggest a better way?
Upvotes: 7
Views: 1985
Reputation: 4718
There are two simpler alternatives:
Use maskedSelect:
result=A:maskedSelect(your_byte_tensor)
Use a simple element-wise multiplication, for example
result = torch.cmul(A, S:gt(0):typeAs(A):expandAs(A)) -- cast the ByteTensor mask to A's type and expand it to A's shape
The second one is very useful if you need to keep the shape of the original matrix (i.e. A), for example to select neurons in a layer at backprop. However, since it puts zeros in the resulting matrix wherever the condition dictated by the ByteTensor doesn't hold, you can't use it to compute the product (or median, etc.). The first one returns only the elements that satisfy the condition, so this is what I'd use to compute products or medians or anything else where I don't want zeros.
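To make the difference concrete, here is a minimal sketch of both options side by side. The data in A and S is made up for illustration, and the threshold of 5 is borrowed from the question. It assumes the 3x1 mask has to be expanded to A's shape before either call, and that it has to be cast to A's type before cmul.

```lua
require 'torch'

-- Hypothetical example data: 3 rows, 2 columns, one score per row.
local A = torch.Tensor{{1, 2}, {3, 4}, {5, 6}}
local S = torch.Tensor{{6}, {1}, {5}}

-- Row mask (3x1 ByteTensor) expanded to A's 3x2 shape.
local mask = S:ge(5):expandAs(A)

-- Option 1: keep only the selected elements, flattened to a 1D tensor.
local picked = A:maskedSelect(mask)

-- Option 2: keep A's shape and zero out the rejected rows.
-- The mask is cast to A's type first so both cmul arguments match.
local zeroed = torch.cmul(A, mask:typeAs(A))

print(picked)         -- 4 elements: 1, 2, 5, 6
print(zeroed)         -- 3x2, middle row is all zeros
print(picked:prod())  -- 60: product of the selected elements only
print(zeroed:prod())  -- 0: the zeros from the rejected row dominate
```

The two products at the end show why the masked multiplication is unsuitable for reductions such as prod, while maskedSelect is fine for them.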
Upvotes: 0
Reputation: 16121
One possible alternative is:
sel = S:ge(5):expandAs(A) -- now you can use this mask with the [] operator
A1 = A[sel]:unfold(1, 2, 2) -- unfold to get back a 2D tensor
Example:
> A = torch.rand(3,2)-1
-0.0047 -0.7976
-0.2653 -0.4582
-0.9713 -0.9660
[torch.DoubleTensor of size 3x2]
> S = torch.Tensor{{6}, {1}, {5}}
6
1
5
[torch.DoubleTensor of size 3x1]
> sel = S:ge(5):expandAs(A)
1 1
0 0
1 1
[torch.ByteTensor of size 3x2]
> A[sel]
-0.0047
-0.7976
-0.9713
-0.9660
[torch.DoubleTensor of size 4]
> A[sel]:unfold(1, 2, 2)
-0.0047 -0.7976
-0.9713 -0.9660
[torch.DoubleTensor of size 2x2]
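If you would rather not hard-code the number of columns in the unfold call, the nonzero/index route from the question can be wrapped in a small helper. This is only a sketch: selectRows is a hypothetical name, not a Torch function, and returning nil when nothing matches is an arbitrary choice.

```lua
require 'torch'

-- Hypothetical helper (not part of Torch): returns the rows of A whose
-- entry in the column vector S satisfies S >= threshold.
local function selectRows(A, S, threshold)
   -- nonzero on an N x 1 mask gives a K x 2 LongTensor of (row, col) pairs.
   local idx = torch.nonzero(S:ge(threshold))
   if idx:nElement() == 0 then
      return nil  -- no row matched
   end
   -- Column 1 holds the row numbers; index copies those rows out of A.
   return A:index(1, idx:select(2, 1))
end

-- Usage with made-up data:
local A = torch.Tensor{{1, 2}, {3, 4}, {5, 6}}
local S = torch.Tensor{{6}, {1}, {5}}
print(selectRows(A, S, 5))  -- rows 1 and 3 of A, as a 2x2 tensor
```

Unlike the mask-and-unfold version, this works for any number of columns in A, at the cost of an extra nonzero call.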
Upvotes: 6