Amir Rosenfeld

Reputation: 371

torch logical indexing of tensor

I'm looking for an elegant way to select a subset of a Torch tensor that satisfies some constraints. For example, say I have:

A = torch.rand(10,2)-1

and S is a 10x1 tensor,

sel = torch.ge(S,5) -- this is a ByteTensor

I would like to be able to do logical indexing, as follows:

A1 = A[sel]

But that doesn't work. There is the index function, which accepts a LongTensor of indices, but I could not find a simple way to convert the mask sel to a LongTensor, except the following:

sel = torch.nonzero(sel)

which returns a K x 2 LongTensor, one (row, column) pair per nonzero entry (K being the number of values of S that are >= 5). I then have to convert it to a 1-dimensional tensor, which finally allows me to index A:

A:index(1,torch.squeeze(sel:select(2,1)))
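For reference, the nonzero/index route above can be written as one self-contained snippet. This is a sketch that assumes Torch7 (Lua) and replaces torch.rand with small fixed values so the result is predictable. select(2, 1) on the K x 2 result already returns a 1D LongTensor of row indices, so the extra squeeze is not needed, and index(1, ...) keeps the result 2D.

```lua
require 'torch'

-- Fixed inputs (assumed for illustration) instead of torch.rand
local A = torch.Tensor{{1, 2}, {3, 4}, {5, 6}}   -- 3x2
local S = torch.Tensor{{6}, {1}, {5}}            -- 3x1

-- nonzero on the 3x1 ByteTensor mask gives a Kx2 LongTensor of (row, col) pairs;
-- column 1 holds the row indices, and select already returns a 1D tensor
local rows = torch.nonzero(S:ge(5)):select(2, 1)

-- Indexing along dimension 1 keeps A's 2D shape: rows 1 and 3 of A
local A1 = A:index(1, rows)
print(A1)
```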

This is very cumbersome; in MATLAB, for example, all I'd have to do is

A(S>=5,:)

Can anyone suggest a better way?

Upvotes: 7

Views: 1985

Answers (2)

Ash

Reputation: 4718

There are two simpler alternatives:

  1. Use maskedSelect:

    result=A:maskedSelect(your_byte_tensor)  -- the mask needs as many elements as A, e.g. S:ge(5):expandAs(A)

  2. Use a simple element-wise multiplication, for example

    result=torch.cmul(A,S:gt(0):typeAs(A):expandAs(A))  -- cmul needs a mask of A's type and size

The second one is very useful if you need to keep the shape of the original matrix (i.e. A), for example to mask neurons in a layer during backprop. However, since it puts zeros in the resulting matrix wherever the condition dictated by the ByteTensor does not hold, you can't use it to compute the product (or median, etc.) of the selected elements.

The first one returns only the elements that satisfy the condition (as a 1D tensor), so this is what I'd use to compute products, medians, or anything else where I don't want the extra zeros.
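A minimal sketch of both options, assuming Torch7 (Lua) and small fixed inputs in place of torch.rand. It first expands the Nx1 mask to A's shape, because both maskedSelect and cmul work element by element, and it converts the mask with typeAs before cmul, since cmul expects operands of the same tensor type. The last two lines illustrate the point about products.

```lua
require 'torch'

-- Fixed inputs (assumed for illustration) instead of torch.rand
local A = torch.Tensor{{1, 2}, {3, 4}, {5, 6}}   -- 3x2
local S = torch.Tensor{{6}, {1}, {5}}            -- 3x1

-- Expand the 3x1 ByteTensor mask so it has one entry per element of A
local mask = S:ge(5):expandAs(A)

-- Option 1: maskedSelect returns a 1D tensor with only the selected elements (1, 2, 5, 6)
local selected = A:maskedSelect(mask)

-- Option 2: element-wise multiplication keeps A's 3x2 shape and zeroes row 2;
-- the ByteTensor mask is converted to A's type first
local masked = torch.cmul(A, mask:typeAs(A))

-- The zeros left by option 2 change reductions such as the product
print(selected:prod())   -- 1*2*5*6
print(masked:prod())     -- zero, because of the masked-out row
```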

Upvotes: 0

deltheil

Reputation: 16121

One possible alternative is to expand the mask to the shape of A, so that it has one entry per element of A:

sel = S:ge(5):expandAs(A)   -- now you can use this mask with the [] operator
A1 = A[sel]:unfold(1, 2, 2) -- unfold to get back a 2D tensor

Example:

> A = torch.rand(3,2)-1
-0.0047 -0.7976
-0.2653 -0.4582
-0.9713 -0.9660
[torch.DoubleTensor of size 3x2]

> S = torch.Tensor{{6}, {1}, {5}}
 6
 1
 5
[torch.DoubleTensor of size 3x1]

> sel = S:ge(5):expandAs(A)
1  1
0  0
1  1
[torch.ByteTensor of size 3x2]

> A[sel]
-0.0047
-0.7976
-0.9713
-0.9660
[torch.DoubleTensor of size 4]

> A[sel]:unfold(1, 2, 2)
-0.0047 -0.7976
-0.9713 -0.9660
[torch.DoubleTensor of size 2x2]
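The unfold(1, 2, 2) call above hardcodes the two columns of A. A sketch that generalizes it to any number of columns, assuming Torch7 (Lua); selectRows is a hypothetical helper name, not a Torch function, and the 3x3 input is made up for illustration:

```lua
require 'torch'

-- Hypothetical helper (not part of Torch): keep the rows of the 2D tensor A
-- for which the Nx1 ByteTensor `cond` is nonzero, returning a 2D tensor
local function selectRows(A, cond)
  local ncol = A:size(2)
  local flat = A[cond:expandAs(A)]    -- 1D: selected rows concatenated row by row
  return flat:unfold(1, ncol, ncol)   -- regroup into rows of ncol elements
end

-- Made-up inputs for illustration
local A = torch.Tensor{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}   -- 3x3
local S = torch.Tensor{{6}, {1}, {5}}                     -- 3x1

local A1 = selectRows(A, S:ge(5))     -- rows 1 and 3 of A, as a 2x3 tensor
print(A1)
```

Because A[mask] copies the selected elements, the result does not share memory with A; unfold only reshapes that copy. The helper assumes at least one row matches; with no matching rows, the unfold call would fail.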

Upvotes: 6
