Sengiley

Reputation: 289

Masked aggregations in pytorch

Given data and mask tensors, is there a PyTorch way to obtain masked aggregations of the data (mean, max, min, etc.)?

import torch

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
])

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

To compute a masked mean I can do the following, but is there a PyTorch built-in or a commonly used package that does this?

n_mask = torch.sum(mask, dim=1)                    # number of kept entries per row
x_mean = torch.sum(x * mask, dim=1) / n_mask       # masked-out entries contribute 0 to the sum

print(x_mean)
# tensor([ 1.5000, 20.0000])
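The same pattern extends to max and min: fill the masked-out positions with a value that can never win the reduction, then reduce normally. A minimal sketch, assuming a boolean mask with the same shape as the data; masked_mean, masked_max and masked_min are hypothetical helper names, not PyTorch API. Rows whose mask is all False come back as 0 for the mean and as the dtype's extreme value for max/min, so check for those separately if they can occur.

```python
import torch


def _lowest(dtype: torch.dtype):
    # Smallest representable value for the dtype (used as the "never wins" fill for max).
    return torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min


def _highest(dtype: torch.dtype):
    # Largest representable value for the dtype (used as the "never wins" fill for min).
    return torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Zero out unselected entries, then divide by the number of selected entries.
    # clamp(min=1) avoids 0/0 for rows with no selected entries (they return 0).
    total = torch.where(mask, x, torch.zeros_like(x)).sum(dim=dim)
    count = mask.sum(dim=dim).clamp(min=1)
    return total / count


def masked_max(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Unselected entries become the lowest value, so they cannot be the maximum.
    return x.masked_fill(~mask, _lowest(x.dtype)).amax(dim=dim)


def masked_min(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Unselected entries become the highest value, so they cannot be the minimum.
    return x.masked_fill(~mask, _highest(x.dtype)).amin(dim=dim)


x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1],
])
mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False],
])

print(masked_mean(x, mask, dim=1))
print(masked_max(x, mask, dim=1))
print(masked_min(x, mask, dim=1))
```

Note that masked_min skips the -1 padding in both rows, which an unmasked x.amin(dim=1) would have returned.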

Upvotes: 0

Views: 29

Answers (1)

Karl

Reputation: 5473

If you'd rather not use torch.masked because it is still a prototype feature, you can use scatter_reduce, which supports the sum, prod, mean, amax and amin reductions.

import torch

x = torch.tensor([
    [1, 2, -1, -1],
    [10, 20, 30, -1]
]).float()  # cast to float: the output buffer below uses x.dtype, and a mean needs a floating dtype

mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False]
])

rows, cols = mask.nonzero().T

for reduction in ['mean', 'sum', 'prod', 'amax', 'amin']:
    output = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)
    output = output.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)
    print(f"{reduction}\t{output}")

# printed output:
# mean  tensor([ 1.5000, 20.0000])
# sum   tensor([ 3., 60.])
# prod  tensor([2.0000e+00, 6.0000e+03])
# amax  tensor([ 2., 30.])
# amin  tensor([ 1., 10.])
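One caveat with the zero-initialised output: with include_self=False, a row that receives no values keeps its initial value, so a row whose mask is all False also reports 0, which looks like a genuine result. A minimal sketch that starts from NaN instead, assuming floating-point data; scatter_masked is a hypothetical helper name and the third row is made-up data for illustration:

```python
import torch

x = torch.tensor([
    [1., 2., -1., -1.],
    [10., 20., 30., -1.],
    [5., 6., 7., 8.],
])
mask = torch.tensor([
    [True, True, False, False],
    [True, True, True, False],
    [False, False, False, False],  # a row with nothing selected
])

# Row and column indices of every selected entry.
rows, cols = mask.nonzero().T


def scatter_masked(x: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor, reduction: str) -> torch.Tensor:
    # Start from NaN rather than 0: rows that never receive a value keep the
    # initial NaN, which marks "no selected entries" unambiguously.
    out = torch.full((x.shape[0],), float("nan"), device=x.device, dtype=x.dtype)
    return out.scatter_reduce(0, rows, x[rows, cols], reduce=reduction, include_self=False)


for reduction in ["mean", "amax", "amin"]:
    print(reduction, scatter_masked(x, rows, cols, reduction))
```

NaN only works for floating dtypes; for integer data a sentinel value, or the per-row count mask.sum(dim=1), can serve the same purpose.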

Upvotes: 2
