zalsaeed

Reputation: 81

How to retain 2D (or more) shape when using PyTorch masked_select

Suppose I have the following two matching shape tensors:

a = tensor([[ 0.0113, -0.1666,  0.5960, -0.0667], [-0.0977, -0.1984,  0.5153,  0.0420]])
selectors = tensor([[ True,  True, False, False], [ True,  False, True, False]])

When using torch.masked_select to pick the values in a at the True positions in selectors, like this:

torch.masked_select(a, selectors)

The output will be in 1D shape instead of the original 2D shape:

tensor([ 0.0113, -0.1666, -0.0977, 0.5153])

This is consistent with masked_select's documented behavior (torch.masked_select). However, my goal is to get a result that retains the 2D structure of the original tensors, i.e.:

tensor([[0.0113, -0.1666], [-0.0977, 0.5153]])

Is there a way to get this without having to loop over all the elements in the tensors and find the mask for each one? Please note that I have also looked into using torch.where, but it doesn't fit the case I have as I see it.
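For context, here is roughly what torch.where gives (a sketch; the zero fill value is arbitrary). It keeps the full original shape with fill values left in place, rather than compacting each row to the left:

```python
import torch

a = torch.tensor([[0.0113, -0.1666, 0.5960, -0.0667],
                  [-0.0977, -0.1984, 0.5153, 0.0420]])
selectors = torch.tensor([[True, True, False, False],
                          [True, False, True, False]])

# torch.where keeps the original (2, 4) shape; masked-out entries
# become the fill value instead of being removed from each row.
out = torch.where(selectors, a, torch.zeros_like(a))
print(out)
# tensor([[ 0.0113, -0.1666,  0.0000,  0.0000],
#         [-0.0977,  0.0000,  0.5153,  0.0000]])
```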

Upvotes: 4

Views: 3134

Answers (1)

elad

Reputation: 31

As @jodag pointed out, for general inputs each row of the desired masked result may have a different number of elements, depending on how many True values there are in the corresponding row of selectors. However, you can work around this by allowing trailing zero padding in the result.

Basic solution:

indices = torch.masked_fill(torch.cumsum(selectors.int(), dim=1), ~selectors, 0)
masked = torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)[:,1:]

Explanation: Applying cumsum() row-wise over selectors computes, for each unmasked element of a, the target column it should be copied to in the output tensor. Then scatter() performs a row-wise scattering of a's elements to these computed target locations. All masked elements keep the index 0, so the first element in each row of the result contains one of the masked elements (arbitrarily; we don't care which). We then discard these unwanted first values by taking the slice [:,1:]. The resulting masked tensor has exactly the same size as the input a (this is the maximum size needed, for the case where some row of selectors is all True).

Usage example:

>>> a = torch.tensor([[ 1,  2,  3,  4,  5,  6], [10, 20, 30, 40, 50, 60]])
>>> selectors = torch.tensor([[ True, False, False,  True, False,  True], [False, False,  True,  True, False, False]])
>>> torch.cumsum(selectors.int(), dim=1)
tensor([[1, 1, 1, 2, 2, 3],
        [0, 0, 1, 2, 2, 2]])
>>> indices = torch.masked_fill(torch.cumsum(selectors.int(), dim=1), ~selectors, 0)
>>> indices
tensor([[1, 0, 0, 2, 0, 3],
        [0, 0, 1, 2, 0, 0]])
>>> torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)
tensor([[ 5,  1,  4,  6,  0,  0],
        [60, 30, 40,  0,  0,  0]])
>>> torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)[:,1:]
tensor([[ 1,  4,  6,  0,  0],
        [30, 40,  0,  0,  0]])

Adapting output size: Here, the length of dim=1 of the resulting masked tensor is the maximum number of unmasked items in a row. To shrink the output to that width, instead of allocating input=torch.zeros_like(a) for scatter(), allocate it with a.new_zeros(size=(a.size(0), torch.max(indices).item() + 1)). The +1 accounts for the first column that is later sliced out. For your original show-case, the output shape would then be (2, 2), as desired. Note that if this width is not known in advance and a is on CUDA, accessing the result of max() to compute the allocation size forces an additional host-device synchronization, which might affect performance.

Example:

>>> torch.scatter(input=a.new_zeros(size=(a.size(0), torch.max(indices).item() + 1)), dim=1, index=indices, src=a)[:,1:]
tensor([[ 1,  4,  6],
        [30, 40,  0]])

Changing the padding value: If a custom default value is wanted as padding, use torch.full_like(a, my_custom_value) rather than torch.zeros_like(a) when allocating the output for scatter().
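For instance, padding with -1 instead of 0 (the fill value -1 here is just an example choice):

```python
import torch

a = torch.tensor([[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60]])
selectors = torch.tensor([[True, False, False, True, False, True],
                          [False, False, True, True, False, False]])
indices = torch.masked_fill(torch.cumsum(selectors.int(), dim=1), ~selectors, 0)

# Allocate the scatter output pre-filled with -1; unwritten slots keep it
out = torch.scatter(input=torch.full_like(a, -1), dim=1, index=indices, src=a)[:, 1:]
print(out)
# tensor([[ 1,  4,  6, -1, -1],
#         [30, 40, -1, -1, -1]])
```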

Upvotes: 2
