john price

Reputation: 23

How to calculate the Mahalanobis distance in PyTorch?

What is the most efficient way to calculate the Mahalanobis distance in PyTorch?

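For reference, the Mahalanobis distance of a point x from a distribution (or cluster) with mean μ and covariance matrix S is usually defined as follows. Both answers compute the quadratic form under the square root, either with einsum or with dot/matmul.

```latex
D_M(\mathbf{x}) = \sqrt{(\mathbf{x} - \boldsymbol{\mu})^\top S^{-1} (\mathbf{x} - \boldsymbol{\mu})}
```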

Upvotes: 2

Views: 5677

Answers (2)

eugenio b

Reputation: 33

If you want to compute the distances with broadcasting, between a whole batch of points (batch_pattern) and several clusters at once, here's a possible solution:

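import torch
diff = batch_pattern - mean_n_cluster  # broadcast -> (batch_size, n_cluster, n_features)
left_term = torch.einsum('bck,cki->bci', diff, inv_cov_n_cluster)  # diff^T @ inv_cov, per cluster
under_radix = torch.einsum('bci,bci->bc', left_term, diff)  # squared distances, (batch_size, n_cluster)
mahalanobis_dist = torch.sqrt(under_radix)  # (batch_size, n_cluster)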

where
batch_pattern (batch_size x 1 x n_features) is the input tensor holding the batch of points,
mean_n_cluster (n_cluster x n_features) is the tensor holding the mean of each cluster, and
inv_cov_n_cluster (n_cluster x n_features x n_features) is the tensor holding the inverse covariance matrix of each cluster.
The result mahalanobis_dist has shape (batch_size x n_cluster).
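A minimal sketch that wraps this broadcasted einsum approach in a function and checks it against a plain per-pair loop on random data. The function names batched_mahalanobis and loop_mahalanobis are hypothetical, and the random covariances (A @ A.T + I, which is symmetric positive definite) are an assumption made only to get invertible matrices. Here diff is formed directly with shape (batch_size, n_cluster, n_features), so the contraction runs over the feature axis without any unsqueeze/squeeze.

```python
import torch


def batched_mahalanobis(batch_pattern, mean_n_cluster, inv_cov_n_cluster):
    """Mahalanobis distance of every point to every cluster (hypothetical helper).

    batch_pattern:     (batch_size, 1, n_features)
    mean_n_cluster:    (n_cluster, n_features)
    inv_cov_n_cluster: (n_cluster, n_features, n_features)
    returns:           (batch_size, n_cluster)
    """
    diff = batch_pattern - mean_n_cluster  # broadcast -> (batch_size, n_cluster, n_features)
    left_term = torch.einsum('bck,cki->bci', diff, inv_cov_n_cluster)
    under_radix = torch.einsum('bci,bci->bc', left_term, diff)
    return torch.sqrt(under_radix)


def loop_mahalanobis(batch_pattern, mean_n_cluster, inv_cov_n_cluster):
    """Reference version: one (point, cluster) pair at a time."""
    b, c = batch_pattern.shape[0], mean_n_cluster.shape[0]
    out = torch.empty(b, c, dtype=batch_pattern.dtype)
    for i in range(b):
        for j in range(c):
            d = batch_pattern[i, 0] - mean_n_cluster[j]
            out[i, j] = torch.sqrt(d @ inv_cov_n_cluster[j] @ d)
    return out


torch.manual_seed(0)
batch_size, n_cluster, n_features = 5, 3, 4
batch_pattern = torch.randn(batch_size, 1, n_features, dtype=torch.float64)
mean_n_cluster = torch.randn(n_cluster, n_features, dtype=torch.float64)
# Assumed covariances: A @ A^T + I is symmetric positive definite, hence invertible
a = torch.randn(n_cluster, n_features, n_features, dtype=torch.float64)
cov = a @ a.transpose(-1, -2) + torch.eye(n_features, dtype=torch.float64)
inv_cov_n_cluster = torch.linalg.inv(cov)

fast = batched_mahalanobis(batch_pattern, mean_n_cluster, inv_cov_n_cluster)
slow = loop_mahalanobis(batch_pattern, mean_n_cluster, inv_cov_n_cluster)
print(fast.shape)  # torch.Size([5, 3])
print(torch.allclose(fast, slow))
```

The vectorized version replaces batch_size x n_cluster small matrix-vector products with two einsum calls, which is where the speedup on larger batches would come from.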

Upvotes: 0

Ivan

Reputation: 40728

Based on SciPy's implementation of the Mahalanobis distance, you could do the following in PyTorch, assuming u and v are 1D tensors and cov is the 2D covariance matrix:

import torch

def mahalanobis(u, v, cov):
    delta = u - v  # difference between the two 1D vectors
    # quadratic form delta^T @ cov^-1 @ delta
    m = torch.dot(delta, torch.matmul(torch.inverse(cov), delta))
    return torch.sqrt(m)

Note: scipy.spatial.distance.mahalanobis takes the inverse of the covariance matrix (its VI argument), whereas the function above takes cov itself and inverts it.
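As a variant, here is a sketch that swaps the explicit inverse for a linear solve with torch.linalg.solve, which computes cov^-1 @ delta without forming cov^-1 and is generally the more numerically stable choice. The function name mahalanobis_solve is hypothetical. With an identity covariance the result should reduce to the Euclidean distance, and with a diagonal covariance each squared component is scaled by its variance.

```python
import torch


def mahalanobis_solve(u, v, cov):
    """Mahalanobis distance between 1D tensors u and v (hypothetical helper).

    Solves cov @ x = delta instead of computing cov^-1 explicitly.
    """
    delta = u - v
    x = torch.linalg.solve(cov, delta)  # x = cov^-1 @ delta
    return torch.sqrt(torch.dot(delta, x))


u = torch.tensor([3.0, 4.0], dtype=torch.float64)
v = torch.tensor([0.0, 0.0], dtype=torch.float64)

# Identity covariance: reduces to the Euclidean distance, sqrt(3**2 + 4**2) = 5
print(mahalanobis_solve(u, v, torch.eye(2, dtype=torch.float64)))

# Diagonal covariance: sqrt(3**2 / 4 + 4**2 / 16) = sqrt(3.25)
cov = torch.diag(torch.tensor([4.0, 16.0], dtype=torch.float64))
print(mahalanobis_solve(u, v, cov))
```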

Upvotes: 2
