CodeHoarder

Reputation: 280

Torch sum each row excluding an index

Given a tensor A of shape (N, C) and an index tensor Idx of shape (N,), I'd like to sum all the elements of each row of A except the one in the column given by the corresponding entry of Idx. For example:

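import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

Idx = torch.tensor([0, 2])

# expected result, shape (N, 1):
# row 0 skips column 0 -> 2 + 3 = 5; row 1 skips column 2 -> 4 + 5 = 9
torch.tensor([[5],
              [9]])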

I already know how to do this with a loop; I'm looking for a vectorized way.
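For reference, a loop-based version might look like the sketch below. This is an assumed reconstruction, not the asker's actual code, and the function name `row_sums_excluding_loop` is hypothetical. It walks every element and skips the column named by `Idx[i]`, which is exactly what a vectorized version should reproduce.

```python
import torch

def row_sums_excluding_loop(A: torch.Tensor, Idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop version: sum row i of A, skipping column Idx[i]."""
    N, C = A.shape
    # Output keeps a trailing dimension so the shape is (N, 1), as in the question
    out = torch.zeros(N, 1, dtype=A.dtype, device=A.device)
    for i in range(N):
        skip = int(Idx[i])
        total = A.new_zeros(())  # 0-d tensor with A's dtype and device
        for j in range(C):
            if j != skip:
                total = total + A[i, j]
        out[i, 0] = total
    return out

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(row_sums_excluding_loop(A, Idx).tolist())  # [[5], [9]]
```

The double loop runs in Python, so it gets slow for large N and C, which is why a vectorized approach is preferable.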

Upvotes: 0

Views: 1936

Answers (1)

Anton Ganichev

Reputation: 2542

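You can set the excluded elements to zero (note that this modifies A in place, so work on A.clone() if you still need the original values):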

A[torch.arange(A.shape[0]), Idx] = 0

and then sum the tensor along each row:

b = A.sum(dim=1, keepdim=True)  # b = tensor([[5], [9]])
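A non-mutating alternative (a different technique from the zeroing above, not part of the original answer) is to subtract the excluded element from the full row sum, i.e. r_i = sum_j A[i, j] - A[i, Idx[i]]. The sketch below uses `torch.gather` to pick out A[i, Idx[i]] for every row at once; the helper name `row_sums_excluding` is hypothetical.

```python
import torch

def row_sums_excluding(A: torch.Tensor, Idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row sums of A minus the element at column Idx[i]."""
    # Full sum of every row, kept as shape (N, 1)
    total = A.sum(dim=1, keepdim=True)
    # gather along dim 1 picks A[i, Idx[i]]; the index must be int64 with shape (N, 1)
    excluded = A.gather(1, Idx.unsqueeze(1))
    return total - excluded

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Idx = torch.tensor([0, 2])
print(row_sums_excluding(A, Idx).tolist())  # [[5], [9]]
```

A is left unchanged here. Because this subtracts instead of skipping, results can differ slightly from the zeroing approach for floating-point inputs due to rounding, and an excluded inf produces nan (inf - inf).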

Upvotes: 1
