Let's say we have two matrices:
```python
import torch

mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape
```
The simplest, naive matrix multiplication looks like this:
```python
def mat_vec_dot_product(mat, vect):
    n, m = mat.shape
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += mat[i][j] * vect[j]
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
```
But what if I need the L2 norm of the difference instead of the dot product? The code is as follows:
```python
def mat_vec_l2_mult(mat, vect):
    n, m = mat.shape
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += (mat[i][j] - vect[j]) ** 2
    res = res.sqrt()
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])
```
Is there an efficient way to do this with PyTorch or another library? The naive O(n^3) Python loops are really slow.
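Use `torch.cdist` for the L2 norm (Euclidean distance):
```python
res = torch.cdist(mat, mat2.permute(1, 0), p=2)
```
Here, `permute` swaps the dimensions of `mat2` from (7, 20) to (20, 7), because `cdist` expects both inputs laid out as (number of vectors, vector length).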
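A quick check, assuming small random inputs and a fixed seed, that `torch.cdist` reproduces the explicit loop from the question. For a 2-D tensor, `mat2.T` is the same transpose as `mat2.permute(1, 0)`:

```python
import torch

torch.manual_seed(0)  # fixed seed so the comparison is reproducible
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100

# Pairwise Euclidean distances: row i of mat vs. column k of mat2.
# cdist wants (num_vectors, vector_length) for both inputs, hence the transpose.
res = torch.cdist(mat, mat2.T, p=2)

# Reference: explicit loops as in the question (slow, for checking only).
ref = torch.zeros([mat.shape[0], mat2.shape[1]])
for i in range(mat.shape[0]):
    for k in range(mat2.shape[1]):
        ref[i, k] = torch.sqrt(((mat[i] - mat2[:, k]) ** 2).sum())

print(res.shape)  # torch.Size([20, 20])
print(torch.allclose(res, ref, rtol=1e-4, atol=1e-3))
```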
First of all, matrix multiplication in PyTorch has a built-in operator, `@`. So, to multiply `mat` and `mat2` you simply do:
```python
mat @ mat2
```
This works as long as the inner dimensions agree, i.e. `mat.shape[1] == mat2.shape[0]`.
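A small sanity check, assuming random inputs with the shapes from the question, that `@` matches the naive triple loop:

```python
import torch

torch.manual_seed(0)
mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])

# Built-in matrix multiplication: (20, 7) @ (7, 20) -> (20, 20)
fast = mat @ mat2

# Naive triple loop, kept only to verify the result.
n, m = mat.shape
slow = torch.zeros([n, mat2.shape[1]])
for i in range(n):
    for k in range(mat2.shape[1]):
        for j in range(m):
            slow[i, k] += mat[i, j] * mat2[j, k]

print(fast.shape)  # torch.Size([20, 20])
print(torch.allclose(fast, slow, atol=1e-4))
```

The two only differ by floating-point summation order, which is why `allclose` is used rather than exact equality.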
Now, to compute the sum of squared differences (SSD, i.e. the squared L2 norm of the differences; its square root is the distance your second block computes), you can use a simple trick.
The squared L2 norm $\|m_i - v\|^2$ (where $m_i$ is the $i$-th row of matrix $M$ and $v$ is the vector) equals the dot product $\langle m_i - v, m_i - v\rangle$, so by linearity of the dot product:

$$\|m_i - v\|^2 = \langle m_i, m_i\rangle - 2\langle m_i, v\rangle + \langle v, v\rangle$$

So you can compute the SSD of each row of $M$ from the vector $v$ by computing once the squared L2 norm of each row, once the dot product between each row and the vector, and once the squared L2 norm of the vector. This can be done in O(n^2).
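The single-vector case of this expansion can be sketched as follows. `mat_vec_l2` is a hypothetical helper name, and the result is checked against direct broadcasting:

```python
import torch

def mat_vec_l2(mat, vect):
    """Hypothetical helper: L2 distance from each row of mat to vect,
    using ||a - b||^2 = <a, a> - 2<a, b> + <b, b>."""
    rows_sq = (mat * mat).sum(dim=1)  # <m_i, m_i> for every row, shape (n,)
    cross = mat @ vect                # <m_i, v> for every row, shape (n,)
    vect_sq = vect @ vect             # <v, v>, a 0-dim tensor
    ssd = rows_sq - 2 * cross + vect_sq
    # Rounding can push ssd slightly below zero for near-identical vectors.
    return ssd.clamp(min=0).sqrt()

torch.manual_seed(0)
mat = torch.randn([20, 7])
v = torch.randn([7])
direct = ((mat - v) ** 2).sum(dim=1).sqrt()  # v is broadcast against every row
print(torch.allclose(mat_vec_l2(mat, v), direct, atol=1e-4))
```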
However, for the SSD between two matrices you will still need O(n^3) operations. Improvements can be made, though, by vectorizing the operations instead of using loops.
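One way to vectorize without the expansion trick is plain broadcasting: build every row/column difference at once and reduce over the last axis. This is a sketch, with `mat_mat_l2_broadcast` as a hypothetical name; it allocates an (n, k, m) intermediate, so it trades memory for simplicity:

```python
import torch

def mat_mat_l2_broadcast(mat, mat2):
    # Hypothetical helper. mat: (n, m); mat2: (m, k).
    # Pairs row i of mat with column j of mat2.
    diff = mat[:, None, :] - mat2.T[None, :, :]  # (n, 1, m) - (1, k, m) -> (n, k, m)
    return (diff ** 2).sum(dim=-1).sqrt()         # reduce over m -> (n, k)

torch.manual_seed(0)
mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])
out = mat_mat_l2_broadcast(mat, mat2)

# Loop reference, for checking only.
ref = torch.zeros([20, 20])
for i in range(20):
    for j in range(20):
        ref[i, j] = ((mat[i] - mat2[:, j]) ** 2).sum().sqrt()

print(out.shape)  # torch.Size([20, 20])
```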
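Here is a simple implementation for two matrices:
```python
import torch

def mat_mat_l2_mult(mat, mat2):
    # Squared L2 norm of each row of mat, repeated across the columns of the result
    rows_norm = (torch.norm(mat, dim=1, p=2, keepdim=True) ** 2).repeat(1, mat2.shape[1])
    # Squared L2 norm of each column of mat2, repeated across the rows of the result
    cols_norm = (torch.norm(mat2, dim=0, p=2, keepdim=True) ** 2).repeat(mat.shape[0], 1)
    rows_cols_dot_product = mat @ mat2
    ssd = rows_norm - 2 * rows_cols_dot_product + cols_norm
    # Rounding can make ssd slightly negative; clamp so sqrt does not return NaN
    return ssd.clamp(min=0).sqrt()

mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])
print(mat_mat_l2_mult(mat, mat2))
```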
Each cell (i, j) of the resulting matrix holds the L2 norm of the difference between row i of `mat` and column j of `mat2`.
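A variant of the function above, assuming broadcasting is acceptable: with `keepdim=True` the norms already have shapes (n, 1) and (1, k), so the `.repeat` calls are not needed. `mat_mat_l2_no_repeat` is a hypothetical name, and the result is compared against `torch.cdist`:

```python
import torch

def mat_mat_l2_no_repeat(mat, mat2):
    # Hypothetical helper; same expansion as mat_mat_l2_mult, without .repeat.
    rows_norm = (mat ** 2).sum(dim=1, keepdim=True)   # (n, 1): squared norm of each row
    cols_norm = (mat2 ** 2).sum(dim=0, keepdim=True)  # (1, k): squared norm of each column
    ssd = rows_norm - 2 * (mat @ mat2) + cols_norm    # broadcasts to (n, k)
    return ssd.clamp(min=0).sqrt()                    # clamp guards against tiny negatives

torch.manual_seed(0)
mat = torch.randn([20, 7])
mat2 = torch.randn([7, 20])
out = mat_mat_l2_no_repeat(mat, mat2)
print(torch.allclose(out, torch.cdist(mat, mat2.T), atol=1e-4))
```

Broadcasting avoids materializing the two repeated (n, k) norm matrices; the arithmetic is otherwise identical.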