Tanay Rastogi

Reputation: 81

Minimum and Mean Euclidean distance between two tensors of different shape

I am quite new to PyTorch and am currently running into out-of-memory issues.

Task: I have two 2D tensors of respective shapes A: [1000, 14] & B: [100000, 14].

I have to find the distance of each row of tensor A from all rows of tensor B. Then, using the calculated distances, I take the minimum (or the mean) distance for each row of A and average that value over all rows of A.

Current Solution: My solution to calculate the mean of the minimum distances:

dist = list()
for row_id in range(A.shape[0]):
      # Minimum distance of a row in A from all rows of B
      dist.append(torch.linalg.norm(A[row_id, :] - B, dim=1).min().item())
result = torch.FloatTensor(dist).mean()

And my solution to calculate the mean of the mean distances:

dist = list()
for row_id in range(A.shape[0]):
      # Mean distance of a row in A from B
      dist.append(torch.linalg.norm(A[row_id, :] - B, dim=1).mean().item())
result = torch.FloatTensor(dist).mean()

Issue: This gives me a result, but it is either very slow (when run on the CPU) or often runs out of memory when run on the GPU. (I have a T4 GPU - 8GB)

Can you please recommend a better way to calculate the Euclidean distance that is faster and does not run out of memory?

Thanks!

Upvotes: 2

Views: 478

Answers (2)

Tanay Rastogi

Reputation: 81

So I asked the same question on the PyTorch Discussions forum as well. The reply there suggested using torch.cdist().

# Pairwise Euclidean distances, shape (A.shape[0], B.shape[0])
distAB = torch.cdist(A, B)
# Mean of the per-row minimum distances
resultMinMeanB = distAB.min(dim=1).values.mean()
# Mean of the per-row mean distances
resultMeanMeanB = distAB.mean()

I have now implemented this approach and see a speed boost in my code. The running time of this function evaluation has dropped by a factor of nearly 3.
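For the shapes in the question, the full float32 distance matrix has 1000 x 100000 entries, roughly 400 MB. If that is still too much next to everything else on the GPU, one option is to call torch.cdist() on blocks of rows of A and accumulate the results. The following is a minimal sketch under that assumption; the function name min_and_mean_distance and the chunk_size parameter are made up for illustration and are not part of PyTorch.

```python
import torch

def min_and_mean_distance(A, B, chunk_size=256):
    """Return (mean of per-row min distance, mean of per-row mean distance).

    A is processed in blocks of `chunk_size` rows, so only a
    (chunk_size x B.shape[0]) distance matrix exists at any one time.
    """
    min_sum = torch.zeros((), dtype=A.dtype, device=A.device)
    mean_sum = torch.zeros((), dtype=A.dtype, device=A.device)
    for start in range(0, A.shape[0], chunk_size):
        # Distances from this block of A to every row of B
        d = torch.cdist(A[start:start + chunk_size], B)
        min_sum += d.min(dim=1).values.sum()
        mean_sum += d.mean(dim=1).sum()
    n_rows = A.shape[0]
    return min_sum / n_rows, mean_sum / n_rows

# Small example; sizes are reduced from the question's to keep it quick.
torch.manual_seed(0)
A = torch.randn(100, 14)
B = torch.randn(2000, 14)
mean_min, mean_mean = min_and_mean_distance(A, B, chunk_size=32)
print(mean_min.item(), mean_mean.item())
```

A smaller chunk_size lowers peak memory at the cost of more kernel launches, so the best value depends on the device.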

Upvotes: 2

Yakov Dan

Reputation: 3372

Sure.

The idea is to use the fact that norm(x-y)^2 = norm(x)^2 + norm(y)^2 - 2 x·y (where x·y is the inner product), together with the outer product.
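For reference, the identity comes from expanding the squared norm as an inner product. The symbols a_i and b_j below are notation introduced here for the ith row of A and the jth row of B:

```latex
\begin{aligned}
\lVert a_i - b_j \rVert^2
  &= (a_i - b_j)^\top (a_i - b_j) \\
  &= \lVert a_i \rVert^2 + \lVert b_j \rVert^2 - 2\, a_i^\top b_j
\end{aligned}
```

Taking the square root of the left-hand side gives the Euclidean distance between the two rows.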

So, see the following code for the case of minimum:

import torch

A = torch.randn(1000, 14)
B = torch.randn(100000, 14)

dist = list()
for row_id in range(A.shape[0]):
      # Minimum distance of a row in A from all rows of B
      dist.append(torch.linalg.norm(A[row_id, :] - B, dim=1).min().item())
result = torch.FloatTensor(dist).mean()

x = torch.linalg.norm(A, dim=1)**2
y = torch.linalg.norm(B, dim=1)**2
o1 = torch.outer(x, torch.ones(B.shape[0]))
o2 = torch.outer(torch.ones(A.shape[0]), y)
n = o1+o2 - 2 * A @ B.t()
s = torch.sqrt(n)

print(torch.allclose(result, s.min(dim=1)[0].mean()))

The last line compares the result of your implementation with mine.

Explanation:

x = torch.linalg.norm(A, dim=1)**2 computes a vector whose elements are squared norms of the rows of A.

Similarly for y = torch.linalg.norm(B, dim=1)**2 and the rows of B. o1 = torch.outer(x, torch.ones(B.shape[0])) is a matrix with as many identical columns as there are rows in B. Each column is the vector of squared norms of the rows of A.

Similarly, o2 = torch.outer(torch.ones(A.shape[0]), y) is a matrix with as many identical rows as there are rows in A, and each row holds the squared norms of the rows of B.

So, the matrix o1+o2 is such that at indices i, j the value is the squared norm of the ith row of A plus the squared norm of the jth row of B.

What remains is to subtract twice the inner product of the ith row of A with the jth row of B, which is done using n = o1+o2 - 2 * A @ B.t()

Now, sqrt(n) has the Euclidean distance between the ith row of A and the jth row of B at indices i, j. What's left is to take the minimum or the mean; in this case, the minimum.
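The same steps can also be written with broadcasting instead of explicit outer products. This is a sketch of a variant, not the code above: the helper name pairwise_euclidean is made up for illustration, and the clamp to zero is an added safeguard (an assumption on my part) against tiny negative values from floating-point rounding, which would otherwise turn into NaN under the square root.

```python
import torch

def pairwise_euclidean(A, B):
    """Euclidean distance between every row of A and every row of B."""
    sq_a = (A * A).sum(dim=1, keepdim=True)   # shape (n, 1): squared norms of rows of A
    sq_b = (B * B).sum(dim=1).unsqueeze(0)    # shape (1, m): squared norms of rows of B
    # Broadcasting (n, 1) + (1, m) plays the role of o1 + o2
    sq_dist = sq_a + sq_b - 2.0 * (A @ B.T)
    # Rounding can leave values slightly below zero; clamp before sqrt
    return sq_dist.clamp(min=0).sqrt()

# Small example; sizes are reduced from the question's to keep it quick.
torch.manual_seed(0)
A = torch.randn(50, 14)
B = torch.randn(400, 14)
D = pairwise_euclidean(A, B)
mean_of_min = D.min(dim=1).values.mean()   # the "minimum" case
mean_of_mean = D.mean(dim=1).mean()        # the "mean" case
print(mean_of_min.item(), mean_of_mean.item())
```

Like the outer-product version, this still materializes the full (n, m) matrix, so it does not reduce peak memory by itself.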

Upvotes: 1
