pete

Reputation: 601

Summing vector pairs efficiently in PyTorch

I'm trying to compute the sum of each pair of distinct rows in a matrix. Suppose I have an m x n matrix, such as

[[1,2,3],
 [4,5,6],
 [7,8,9]]

and I want to build a matrix of the sums of all pairs of rows. For the matrix above, the result would be

[[5,7,9],
 [8,10,12],
 [11,13,15]]

In general, the new matrix will be (m choose 2) x n. For the example above, I ran the following in PyTorch:

import torch

x = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])

y = x[None] + x[:, None]

torch.cat((y[0, 1:3, :], y[1, 2:3, :]))

This produces the matrix I am looking for, but only because I specified the indices by hand. I'm struggling to find a way to build the output without hard-coding indices and without using a for-loop. Is there a way to do this for an arbitrary m x n matrix without a for-loop?
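The broadcasting step in the question can be unpacked as follows. This sketch only reuses the question's own x and y and prints what y holds: entry y[i, j] is the sum of rows i and j, so the hand-written slices simply pick the pairs (0, 1), (0, 2) and (1, 2).

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# x[None] has shape (1, m, n) and x[:, None] has shape (m, 1, n);
# broadcasting the sum gives shape (m, m, n) with y[i, j] = x[i] + x[j].
y = x[None] + x[:, None]
print(y.shape)  # torch.Size([3, 3, 3])

# Each entry is the sum of two rows, e.g. rows 2 and 0.
print(torch.equal(y[2, 0], x[2] + x[0]))  # True

# y is symmetric in (i, j), so only one triangle (i < j or i > j) is needed.
print(torch.equal(y, y.transpose(0, 1)))  # True

# The question's slices select the pairs (0, 1), (0, 2) and (1, 2).
manual = torch.cat((y[0, 1:3, :], y[1, 2:3, :]))
print(manual.tolist())  # [[5, 7, 9], [8, 10, 12], [11, 13, 15]]
```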

Upvotes: 0

Answers (1)

jhso

Reputation: 3283

You can try using this function:

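import torch

def sum_rows(x):
    # y[i, j] = x[i] + x[j]; y has shape (m, m, n)
    y = x[None] + x[:, None]
    # (row, col) indices of the strictly lower triangle, i.e. pairs with i > j
    ind = torch.tril_indices(x.shape[0], x.shape[0], offset=-1)
    return y[ind[0], ind[1]]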

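Because you want one sum per unordered pair, i.e. the entries y[i, j] with i > j (i < j would work just as well), you only need the lower (or upper) triangle indices over the first two dimensions of the 3D tensor y. torch.tril_indices with offset=-1 returns exactly those strictly-lower-triangle indices. AFAIK the indexing may still loop internally, but there is no explicit Python for-loop, and the function works for inputs with any number of rows.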
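A possible refinement, sketched here as an assumption rather than the answerer's method: y materializes all m * m row sums (an m x m x n tensor) and then discards more than half of them. Indexing x directly with the same triangle indices yields the same rows in the same order while only allocating the (m choose 2) x n result. The function name sum_row_pairs is hypothetical.

```python
import torch


def sum_row_pairs(x: torch.Tensor) -> torch.Tensor:
    """Return the sums of all pairs of distinct rows of a 2D tensor x.

    The output has shape (m * (m - 1) // 2, n), ordered like torch.tril_indices
    with offset=-1, i.e. pairs (i, j) with i > j in row-major order.
    """
    m = x.shape[0]
    # ind[0] holds row indices i and ind[1] holds column indices j, with i > j.
    ind = torch.tril_indices(m, m, offset=-1, device=x.device)
    # Gather the two rows of each pair and add them; no m x m x n intermediate.
    return x[ind[0]] + x[ind[1]]


x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(sum_row_pairs(x).tolist())  # [[5, 7, 9], [8, 10, 12], [11, 13, 15]]

# Works for any number of rows: 5 rows give 5 choose 2 = 10 pair sums.
big = torch.arange(20).reshape(5, 4)
print(sum_row_pairs(big).shape)  # torch.Size([10, 4])
```

torch.combinations(torch.arange(m), r=2) also produces the index pairs without a Python loop, but with i < j in lexicographic order, so for m > 3 its rows come out in a different order than tril_indices.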

Upvotes: 1