Reputation: 81
Suppose I have the following two tensors of matching shape:
a = tensor([[ 0.0113, -0.1666, 0.5960, -0.0667], [-0.0977, -0.1984, 0.5153, 0.0420]])
selectors = tensor([[ True, True, False, False], [ True, False, True, False]])
When using torch.masked_select to find the values in a that match the True entries in selectors, like this:
torch.masked_select(a, selectors)
The output will be in 1D shape instead of the original 2D shape:
tensor([ 0.0113, -0.1666, -0.0977, 0.5153])
This is consistent with the behavior of masked_select as described in the documentation (torch.masked_select). However, my goal is to get a result that keeps the row structure of the two original tensors, i.e.:
tensor([[0.0113, -0.1666], [-0.0977, 0.5153]])
Is there a way to get this without having to loop over all the elements in the tensors and apply the mask to each one? Please note that I have also looked into using torch.where, but as far as I can see it doesn't fit my case.
Upvotes: 4
Views: 3134
Reputation: 31
As @jodag pointed out, for general inputs each row of the desired masked result might have a different number of elements, depending on how many True values there are in the corresponding row of selectors. However, you can work around this by allowing trailing zero padding in the result.
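To check whether padding is needed at all, one can count the selected elements per row. The following is a minimal sketch using the tensors from the question; the names counts and reshaped are only illustrative. When every row selects the same number of elements, the flat output of masked_select can simply be reshaped.

```python
import torch

a = torch.tensor([[0.0113, -0.1666, 0.5960, -0.0667],
                  [-0.0977, -0.1984, 0.5153, 0.0420]])
selectors = torch.tensor([[True, True, False, False],
                          [True, False, True, False]])

# Number of selected elements in each row.
counts = selectors.sum(dim=1)

# A rectangular result without padding only exists if all counts agree.
if bool((counts == counts[0]).all()):
    # masked_select returns the selected values in row-major order,
    # so a plain reshape restores the rows.
    reshaped = torch.masked_select(a, selectors).view(a.size(0), -1)
    print(reshaped.shape)
else:
    print("rows select different numbers of elements; padding is needed")
```

For the question's input both rows select two elements, so the reshape branch applies; the padded approach covers the general case.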
Basic solution:
indices = torch.masked_fill(torch.cumsum(selectors.int(), dim=1), ~selectors, 0)
masked = torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)[:,1:]
Explanation:
By applying cumsum() row-wise over selectors, we compute, for each unmasked element in a, the target column number it should be copied to in the output tensor. Then scatter() performs a row-wise scattering of a's elements to these computed target locations. All masked elements are given the index 0, so the first element in each row of the result contains one of the masked elements (an arbitrary one; we don't care which). We then discard these unwanted first values by taking the slice [:,1:]. After slicing, the resulting masked tensor has one column fewer than the input a. Note that a row of selectors that is entirely True would need the target index a.size(1), which is out of range for torch.zeros_like(a); to cover that case, allocate the scatter target with one extra column, for example a.new_zeros((a.size(0), a.size(1) + 1)).
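The steps above can be wrapped into a small function. This is a sketch under a few assumptions: the name masked_select_padded is hypothetical, both inputs are 2D with the same shape, and the scatter target is allocated one column wider than a so that a fully selected row also fits.

```python
import torch


def masked_select_padded(a: torch.Tensor, selectors: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pack the selected elements of each row to the left,
    padding the remainder of the row with zeros."""
    # 1-based target column for every selected element, 0 for masked-out ones.
    # .long() is used because scatter expects an int64 index tensor.
    indices = torch.cumsum(selectors.long(), dim=1).masked_fill(~selectors, 0)
    # One extra column: a fully selected row produces index a.size(1).
    target = a.new_zeros((a.size(0), a.size(1) + 1))
    # Column 0 collects the masked-out elements and is dropped afterwards.
    return torch.scatter(target, 1, indices, a)[:, 1:]


a = torch.tensor([[1, 2, 3], [4, 5, 6]])
selectors = torch.tensor([[True, True, True], [False, True, False]])
result = masked_select_padded(a, selectors)
print(result)
```

Several masked-out elements are written to column 0 of the same row, and which one ends up there is not defined; since that column is sliced off, it does not affect the result.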
Usage example:
>>> import torch
>>> a = torch.tensor([[ 1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60]])
>>> selectors = torch.tensor([[ True, False, False, True, False, True], [False, False, True, True, False, False]])
>>> torch.cumsum(selectors.int(), dim=1)
tensor([[1, 1, 1, 2, 2, 3],
[0, 0, 1, 2, 2, 2]])
>>> indices = torch.masked_fill(torch.cumsum(selectors.int(), dim=1), ~selectors, 0)
>>> indices
tensor([[1, 0, 0, 2, 0, 3],
[0, 0, 1, 2, 0, 0]])
>>> torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)
tensor([[ 5, 1, 4, 6, 0, 0],
[60, 30, 40, 0, 0, 0]])
>>> torch.scatter(input=torch.zeros_like(a), dim=1, index=indices, src=a)[:,1:]
tensor([[ 1, 4, 6, 0, 0],
[30, 40, 0, 0, 0]])
Adapting output size: Here, the length of dim=1 of the resulting masked tensor is the maximum number of unmasked items in any row. For your original example, the output shape would be (2, 2), as you wanted. Note that if this number is not known in advance and a is on CUDA, computing it causes an additional host-device synchronization that might affect performance.
To do so, instead of allocating input=torch.zeros_like(a) for scatter(), allocate it with a.new_zeros(size=(a.size(0), torch.max(indices).item() + 1)). The +1 accounts for the first column, which is later sliced off. The host-device synchronization occurs when .item() reads the result of max() to compute the allocated output size.
Example:
>>> torch.scatter(input=a.new_zeros(size=(a.size(0), torch.max(indices).item() + 1)), dim=1, index=indices, src=a)[:,1:]
tensor([[ 1, 4, 6],
[30, 40, 0]])
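Because padding zeros cannot be told apart from genuine zeros in a, it may help to keep a validity mask next to the packed result. The sketch below reuses the usage example; the names width, packed, and valid are illustrative and not part of the original answer.

```python
import torch

a = torch.tensor([[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60]])
selectors = torch.tensor([[True, False, False, True, False, True],
                          [False, False, True, True, False, False]])

# 1-based target columns for selected elements, 0 for masked-out ones.
indices = torch.cumsum(selectors.long(), dim=1).masked_fill(~selectors, 0)

# Largest number of selected elements in any row.
# .item() copies the value to the host, which forces a sync on CUDA.
width = int(indices.max().item())

target = a.new_zeros((a.size(0), width + 1))
packed = torch.scatter(target, 1, indices, a)[:, 1:]

# True where packed holds a real element, False where it holds padding.
counts = selectors.sum(dim=1, keepdim=True)
valid = torch.arange(width, device=a.device) < counts

print(packed)
print(valid)
```

The mask is built by broadcasting a row of column positions against the per-row counts, so it has the same shape as packed.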
Changing the padding value: If a different padding value is wanted, one could use torch.full_like(a, my_custom_value) rather than torch.zeros_like(a) when allocating the output for scatter().
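As a short illustration of a custom padding value, the sketch below pads with -1. The name pad_value is hypothetical, and a.new_full is used here instead of torch.full_like only so that the target can be allocated one column wider than a.

```python
import torch

pad_value = -1  # hypothetical padding value chosen for this example

a = torch.tensor([[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60]])
selectors = torch.tensor([[True, False, False, True, False, True],
                          [False, False, True, True, False, False]])

# 1-based target columns for selected elements, 0 for masked-out ones.
indices = torch.cumsum(selectors.long(), dim=1).masked_fill(~selectors, 0)

# Target pre-filled with the padding value, one column wider than a.
target = a.new_full((a.size(0), a.size(1) + 1), pad_value)

# Unwritten positions keep pad_value; column 0 is dropped.
padded = torch.scatter(target, 1, indices, a)[:, 1:]
print(padded)
```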
Upvotes: 2