Reputation: 599
I want to average the rows of a tensor according to their class labels. The sample tensor:
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],   # class1
        [ 6.,  7.,  8.,  9., 10., 11.],   # class3
        [12., 13., 14., 15., 16., 17.],   # class2
        [18., 19., 20., 21., 22., 23.],   # class0
        [24., 25., 26., 27., 28., 29.]])  # class1
The expected result, with one averaged row per class:
tensor([[18., 19., 20., 21., 22., 23.],  # class0
        [12., 13., 14., 15., 16., 17.],  # class1
        [12., 13., 14., 15., 16., 17.],  # class2
        [ 6.,  7.,  8.,  9., 10., 11.]]) # class3
Is there a pure PyTorch method to implement this?
Upvotes: 2
Views: 1572
Reputation: 114816
You can sum the rows according to their class index using index_add_, and then divide each row by the number of rows in that class, computed using unique:
import torch

# inputs
x = torch.arange(30.).view(5, 6)                     # sample tensor
c = torch.tensor([1, 3, 2, 0, 1], dtype=torch.long)  # class indices
# allocate space for the output: one row per class
result = torch.zeros((c.max() + 1, x.shape[1]), dtype=x.dtype)
# use index_add_ to sum up the rows according to class
result.index_add_(0, c, x)
# use unique to count how many rows belong to each class
# (unique only counts labels that occur, so this assumes every
#  class in 0..c.max() appears at least once)
_, counts = torch.unique(c, return_counts=True)
# divide the per-class sums by the counts to get the averages
result /= counts[:, None]
The result is as expected:
Out[*]:
tensor([[18., 19., 20., 21., 22., 23.],
        [12., 13., 14., 15., 16., 17.],
        [12., 13., 14., 15., 16., 17.],
        [ 6.,  7.,  8.,  9., 10., 11.]])
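One caveat worth noting: torch.unique only returns counts for labels that actually occur, so if some class in 0..c.max() has no rows, the counts no longer line up with the rows of result. A minimal sketch of a more defensive variant, assuming the same x and c as above, that swaps unique for torch.bincount (which returns one count for every index up to the maximum):

import torch

x = torch.arange(30.).view(5, 6)
c = torch.tensor([1, 3, 2, 0, 1], dtype=torch.long)

result = torch.zeros((c.max() + 1, x.shape[1]), dtype=x.dtype)
result.index_add_(0, c, x)

# bincount returns a count for every index in 0..c.max(),
# including zeros for classes that never occur
counts = torch.bincount(c)

# clamp empty classes to 1 so their all-zero sum rows are not divided by zero
result /= counts.clamp(min=1)[:, None]

For the sample data this produces the same tensor as above; the difference only matters when a class index in the range has no rows.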
Upvotes: 2