Given a tensor A of shape (N, C) and an index tensor Idx of shape (N,), I'd like to sum all the elements of each row of A except the one in the column given by the corresponding entry of Idx. For example:
A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Idx = torch.tensor([0, 2])
# result:
torch.tensor([[5],
              [9]])
I already have a solution that uses loops.
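The loop-based version isn't shown in the question, so here is a hypothetical reference implementation (the name exclude_sum_loop is made up for illustration). It is a minimal sketch, assuming A is a 2-D tensor and Idx holds one valid column index per row, and it is handy for checking any vectorized answer against.

```python
import torch

def exclude_sum_loop(A, Idx):
    # Hypothetical reference: for each row i, sum every column except Idx[i].
    out = torch.empty(A.shape[0], 1, dtype=A.dtype)
    for i in range(A.shape[0]):
        keep = torch.ones(A.shape[1], dtype=torch.bool)
        keep[Idx[i]] = False              # drop the excluded column for this row
        out[i, 0] = A[i][keep].sum()      # sum the remaining columns
    return out

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(exclude_sum_loop(A, Idx).tolist())  # [[5], [9]]
```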
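You can set the excluded elements to zero (note that this modifies A in place; use A = A.clone() first if you need to keep the original values):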
A[range(A.shape[0]),Idx] = 0
and then sum along each row:
b = A.sum(dim=1, keepdim=True)  # b = torch.tensor([[5], [9]])
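If A must not be modified, two vectorized alternatives avoid the in-place write. This is a sketch, not the answerer's method: the helper name exclude_sum is hypothetical. The first subtracts the excluded element (picked out with torch.gather) from the full row sum. The second builds a boolean mask and zeroes the excluded entries on a copy via masked_fill.

```python
import torch

def exclude_sum(A, Idx):
    # Full row sums, shape (N, 1).
    row_sums = A.sum(dim=1, keepdim=True)
    # A[i, Idx[i]] for each row i; gather expects an index of shape (N, 1).
    excluded = A.gather(1, Idx.unsqueeze(1))
    # Row sum minus the excluded element; A itself is left untouched.
    return row_sums - excluded

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(exclude_sum(A, Idx).tolist())  # [[5], [9]]

# Mask variant: False marks the excluded column in each row.
mask = torch.ones_like(A, dtype=torch.bool)
mask[torch.arange(A.shape[0]), Idx] = False
# masked_fill returns a new tensor, so A is still unchanged.
masked = A.masked_fill(~mask, 0).sum(dim=1, keepdim=True)
print(masked.tolist())  # [[5], [9]]
```

For integer tensors both give identical results. With floating-point data the subtraction variant can differ from a masked sum by rounding error, so the mask version is the safer choice when exact agreement matters.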