Reputation: 23
What is the most efficient way to calculate the Mahalanobis distance in PyTorch?
Upvotes: 2
Views: 5677
Reputation: 33
If you want to broadcast the computation over a batch of points and multiple clusters, here's a possible solution:
diff = (batch_pattern - mean_n_cluster).unsqueeze(2)
left_term = torch.einsum('bcik,cki->bci',diff,inv_cov_n_cluster)
under_radix = torch.einsum('bci,bci->bc',left_term,diff.squeeze(2))
mahalanobis_dist = torch.sqrt(under_radix)
where
batch_pattern (batch_size x 1 x n_features) is the input tensor holding the batch of points,
mean_n_cluster (n_cluster x n_features) is the tensor holding the mean of each cluster,
inv_cov_n_cluster (n_cluster x n_features x n_features) is the tensor holding the inverse covariance matrix of each cluster.
The result, mahalanobis_dist, has shape batch_size x n_cluster.
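For reference, the same quadratic form can also be written as a single einsum contraction, without the unsqueeze/squeeze round trip. This is only a sketch: the function name batched_mahalanobis and the toy data are illustrative, not part of the answer above, and it assumes inv_cov_n_cluster already holds valid inverse covariance matrices (symmetric positive definite).

```python
import torch

def batched_mahalanobis(batch_pattern, mean_n_cluster, inv_cov_n_cluster):
    """Hypothetical helper: distance from every point to every cluster.

    batch_pattern:     (batch_size, 1, n_features)
    mean_n_cluster:    (n_cluster, n_features)
    inv_cov_n_cluster: (n_cluster, n_features, n_features)
    returns:           (batch_size, n_cluster)
    """
    # Broadcasting (b, 1, f) - (c, f) gives (b, c, f).
    diff = batch_pattern - mean_n_cluster
    # Quadratic form diff^T @ inv_cov @ diff for every (point, cluster) pair.
    under_radix = torch.einsum('bci,cij,bcj->bc', diff, inv_cov_n_cluster, diff)
    # Clamp guards against tiny negative values from floating-point round-off.
    return torch.sqrt(under_radix.clamp_min(0))


# Toy data (assumed shapes, random values).
torch.manual_seed(0)
b, c, f = 5, 3, 4
points = torch.randn(b, 1, f, dtype=torch.float64)
means = torch.randn(c, f, dtype=torch.float64)
a = torch.randn(c, f, f, dtype=torch.float64)
# a @ a^T + f * I is symmetric positive definite by construction.
cov = a @ a.transpose(1, 2) + f * torch.eye(f, dtype=torch.float64)
inv_cov = torch.linalg.inv(cov)

dist = batched_mahalanobis(points, means, inv_cov)
print(dist.shape)  # torch.Size([5, 3])
```

Each entry dist[i, k] is the distance from point i to cluster k, so something like dist.argmin(dim=1) would give the nearest cluster per point.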
Upvotes: 0
Reputation: 40728
Based on SciPy's implementation of the Mahalanobis distance, you could do this in PyTorch as follows, assuming u and v are 1D and cov is the 2D covariance matrix.
import torch

def mahalanobis(u, v, cov):
    delta = u - v
    # Quadratic form: delta^T @ cov^{-1} @ delta
    m = torch.dot(delta, torch.matmul(torch.inverse(cov), delta))
    return torch.sqrt(m)
Note: scipy.spatial.distance.mahalanobis takes the inverse of the covariance matrix as its third argument, whereas the function above takes the covariance matrix itself and inverts it.
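Since the question asks about efficiency, one common alternative is to avoid forming the inverse explicitly and solve a triangular system instead. The sketch below uses a Cholesky factor of cov; it assumes cov is symmetric positive definite, and the name mahalanobis_cholesky is just for illustration.

```python
import torch

def mahalanobis_cholesky(u, v, cov):
    """Illustrative variant: Mahalanobis distance without torch.inverse.

    Assumes u and v are 1D and cov is symmetric positive definite.
    """
    delta = (u - v).unsqueeze(-1)       # column vector, shape (n, 1)
    L = torch.linalg.cholesky(cov)      # cov = L @ L.T, L lower triangular
    # Solve L z = delta; then delta^T cov^{-1} delta == ||z||^2.
    z = torch.linalg.solve_triangular(L, delta, upper=False)
    return torch.linalg.vector_norm(z)


# Toy check against the explicit-inverse formula (random data, assumed sizes).
torch.manual_seed(0)
n = 4
a = torch.randn(n, n, dtype=torch.float64)
cov = a @ a.T + n * torch.eye(n, dtype=torch.float64)  # SPD by construction
u = torch.randn(n, dtype=torch.float64)
v = torch.randn(n, dtype=torch.float64)

d_chol = mahalanobis_cholesky(u, v, cov)
d_inv = torch.sqrt(torch.dot(u - v, torch.inverse(cov) @ (u - v)))
print(torch.allclose(d_chol, d_inv))  # True
```

Whether this is actually faster depends on the matrix size and device. When many points are compared against the same covariance, it probably makes sense to factor once and reuse L.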
Upvotes: 2