Alex Savinov

Reputation: 33

In PyTorch, compute Euclidean distance instead of matrix multiplication

Let's say we have two matrices:

mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100

n, m = mat.shape

The simplest usual matrix multiplication looks like this:

def mat_vec_dot_product(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += mat[i][j] * vect[j]
        
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
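As a sanity check (an illustration, not part of the original question), the naive loop should agree with PyTorch's built-in matrix multiplication up to floating-point rounding:

```python
import torch

torch.manual_seed(0)
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape

def mat_vec_dot_product(mat, vect):
    # naive row-by-row dot product, as in the question
    n, m = mat.shape
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += mat[i][j] * vect[j]
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])

# should match the built-in matmul up to float32 rounding
print(torch.allclose(res, mat @ mat2, rtol=1e-4, atol=1e-2))
```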
    

But what if I need to apply the L2 norm instead of the dot product? The code is as follows:

def mat_vec_l2_mult(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += (mat[i][j] - vect[j]) ** 2
            
    res = res.sqrt()
        
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])

Can we do this in an optimal way using Torch or any other library? The naive O(n^3) Python code runs really slowly.

Upvotes: 3

Views: 2492

Answers (2)

Dishin H Goyani

Reputation: 7713

Use torch.cdist for the L2 norm (Euclidean distance):

res = torch.cdist(mat, mat2.permute(1,0), p=2)

Here, permute is used to swap the dimensions of mat2 from (7, 20) to (20, 7).
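As a quick sanity check (an illustration, not part of the original answer), one entry of the cdist result can be compared against an explicitly computed Euclidean distance:

```python
import torch

torch.manual_seed(0)
mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])

# cdist expects both inputs with shape (num_points, num_features),
# so mat2 is transposed from (7, 20) to (20, 7) first
res = torch.cdist(mat, mat2.permute(1, 0), p=2)

# entry (i, j) should be the distance between row i of mat and column j of mat2
i, j = 3, 5
expected = torch.norm(mat[i] - mat2[:, j], p=2)
print(torch.allclose(res[i, j], expected, atol=1e-5))
```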

Upvotes: 2

Gil Pinsky

Reputation: 2493

First of all, matrix multiplication in PyTorch has a built-in operator: @. So, to multiply mat and mat2 you simply do:

mat @ mat2

(should work, assuming dimensions agree).

Now, to compute the sum of squared differences (SSD; its square root is the L2 norm of the difference), which you seem to compute in your second block, you can use a simple trick. The squared L2 norm ||m_i - v||^2 (where m_i is the i-th row of matrix M and v is the vector) equals the dot product <m_i - v, m_i - v>, so by linearity of the dot product you obtain <m_i, m_i> - 2<m_i, v> + <v, v>.

This means you can compute the SSD of every row of M from vector v by computing the squared L2 norm of each row once, the dot product between each row and the vector once, and the squared L2 norm of the vector once, all in O(n^2). For the SSD between two matrices you still get O(n^3) work overall, but vectorizing the operations instead of using Python loops is a large practical improvement. Here is a simple implementation for two matrices:
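The identity ||m_i - v||^2 = <m_i, m_i> - 2<m_i, v> + <v, v> can be verified numerically (a small sketch in double precision to keep rounding noise negligible):

```python
import torch

torch.manual_seed(0)
# double precision keeps rounding noise negligible for this check
m_i = torch.randn(7, dtype=torch.float64)
v = torch.randn(7, dtype=torch.float64)

lhs = torch.norm(m_i - v, p=2) ** 2        # ||m_i - v||^2
rhs = m_i @ m_i - 2 * (m_i @ v) + v @ v    # <m_i,m_i> - 2<m_i,v> + <v,v>
print(torch.allclose(lhs, rhs))
```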

def mat_mat_l2_mult(mat, mat2):
    # squared L2 norm of each row of mat, broadcast across columns of mat2
    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True) ** 2).repeat(1, mat2.shape[1])
    # squared L2 norm of each column of mat2, broadcast across rows of mat
    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True) ** 2).repeat(mat.shape[0], 1)
    # all row-column dot products in one matmul
    rows_cols_dot_product = mat @ mat2
    # ||m_i||^2 - 2<m_i, c_j> + ||c_j||^2 = ||m_i - c_j||^2
    ssd = rows_norm - 2 * rows_cols_dot_product + cols_norm
    return ssd.sqrt()

mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])
print(mat_mat_l2_mult(mat, mat2))

The resulting matrix has, at each cell (i, j), the L2 norm of the difference between row i of mat and column j of mat2.
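As a sanity check (an addition, not part of the original answer), the vectorized implementation can be compared against the torch.cdist approach from the other answer:

```python
import torch

torch.manual_seed(0)
mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])

def mat_mat_l2_mult(mat, mat2):
    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True) ** 2).repeat(1, mat2.shape[1])
    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True) ** 2).repeat(mat.shape[0], 1)
    rows_cols_dot_product = mat @ mat2
    ssd = rows_norm - 2 * rows_cols_dot_product + cols_norm
    return ssd.sqrt()

res = mat_mat_l2_mult(mat, mat2)
# torch.cdist treats rows as points, so transpose mat2 to (20, 7)
ref = torch.cdist(mat, mat2.permute(1, 0), p=2)
print(torch.allclose(res, ref, atol=1e-3))
```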

Upvotes: 2
