Fosa

Reputation: 572

Calculating Euclidean norm in PyTorch: trouble understanding an implementation

I've seen another Stack Overflow thread discussing various implementations for calculating the Euclidean norm, and I'm having trouble seeing why and how one particular implementation works.

The code comes from an implementation of the MMD (maximum mean discrepancy) metric: https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py

Here is some setup code:

import torch
sample_1, sample_2 = torch.ones((10,2)), torch.zeros((10,2))

The next part picks up from the linked code. I'm unsure why the samples are concatenated together:

sample_12 = torch.cat((sample_1, sample_2), 0)
distances = pdist(sample_12, sample_12, norm=2)

The concatenated tensor is then passed to the pdist function:

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all squared pairwise distances.
    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""

    # Here we get to the meat of the calculation.
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
             norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))

I am at a loss as to why the Euclidean norm would be calculated this way. Any insight would be greatly appreciated.
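One way to see what the concatenation buys: a distance call on sample_12 against itself returns a single (n_1 + n_2) × (n_1 + n_2) matrix whose blocks hold the sample_1-vs-sample_1, sample_2-vs-sample_2 and sample_1-vs-sample_2 distances. An MMD estimate needs kernel values for all three kinds of pairs, so a plausible reading is that concatenating lets one call produce them all. In this sketch torch.cdist stands in for pdist, and the block names are illustrative:

```python
import torch

sample_1, sample_2 = torch.ones((10, 2)), torch.zeros((10, 2))
n_1 = sample_1.size(0)

# Stack both samples row-wise: shape (n_1 + n_2, d) = (20, 2).
sample_12 = torch.cat((sample_1, sample_2), 0)

# All pairwise distances within the pooled sample (torch.cdist stands in for pdist).
distances = torch.cdist(sample_12, sample_12, p=2.0)

within_1 = distances[:n_1, :n_1]  # sample_1 vs sample_1
within_2 = distances[n_1:, n_1:]  # sample_2 vs sample_2
cross_12 = distances[:n_1, n_1:]  # sample_1 vs sample_2

print(distances.shape)  # torch.Size([20, 20])
print(within_1.max().item(), within_2.max().item())
print(cross_12.min().item())  # approximately sqrt(2), the distance from (1, 1) to (0, 0)
```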

Upvotes: 5

Views: 9060

Answers (1)

Milo Lu

Reputation: 3366

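Let's walk through this block of code step by step. The Euclidean distance between two vectors x and y, i.e. the L2 norm of their difference, is

$$\lVert \mathbf{x} - \mathbf{y} \rVert_2 = \sqrt{\sum_{k} (x_k - y_k)^2}$$

Expanding each term as $(x_k - y_k)^2 = x_k^2 + y_k^2 - 2 x_k y_k$ gives

$$\lVert \mathbf{x} - \mathbf{y} \rVert_2^2 = \lVert \mathbf{x} \rVert_2^2 + \lVert \mathbf{y} \rVert_2^2 - 2\, \mathbf{x}^\top \mathbf{y}$$

and this expanded form is what the code computes.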
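The whole trick rests on the expansion ||x - y||² = ||x||² + ||y||² - 2 x·y. A minimal pure-Python check of that identity on two made-up vectors:

```python
# Check ||x - y||^2 == ||x||^2 + ||y||^2 - 2 * <x, y> on two made-up vectors.
x = [1.0, 2.0, 3.0]
y = [4.0, 0.5, -1.0]

def dot(u, v):
    """Inner product of two equal-length sequences."""
    return sum(ui * vi for ui, vi in zip(u, v))

direct = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
expanded = dot(x, x) + dot(y, y) - 2 * dot(x, y)

print(direct, expanded)  # 27.25 27.25
```

Because ||x||² depends only on x and ||y||² only on y, each can be computed once per row; the only term that mixes the two samples is the inner product, and for all pairs at once that is a single matrix multiplication.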

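Let's consider the simplest case: two samples a and b, each holding two 2-dimensional vectors:

$$a = \begin{bmatrix} a_{00} & a_{01} \\ a_{10} & a_{11} \end{bmatrix}, \qquad b = \begin{bmatrix} b_{00} & b_{01} \\ b_{10} & b_{11} \end{bmatrix}$$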

Sample a has two vectors, [a00, a01] and [a10, a11], and likewise for sample b. Let's first calculate the squared norm of each vector:

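import torch
a, b = torch.rand(2, 2), torch.rand(2, 2)  # two samples, each with two 2-D vectors
n_1, n_2 = a.size(0), b.size(0)  # here both n_1 and n_2 have the value 2
norms_1 = torch.sum(a**2, dim=1, keepdim=True)  # shape (2, 1)
norms_2 = torch.sum(b**2, dim=1, keepdim=True)  # shape (2, 1)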

Now we get

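$$\text{norms\_1} = \begin{bmatrix} a_{00}^2 + a_{01}^2 \\ a_{10}^2 + a_{11}^2 \end{bmatrix}, \qquad \text{norms\_2} = \begin{bmatrix} b_{00}^2 + b_{01}^2 \\ b_{10}^2 + b_{11}^2 \end{bmatrix}$$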
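The pdist code passes keepdim=True so the row norms stay 2-D column vectors of shape (n, 1); that is what allows norms_2.transpose(0, 1) to turn them into a row vector in the next step. A small sketch with made-up values standing in for a and b:

```python
import torch

# Made-up 2x2 samples standing in for a and b.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 1.0], [2.0, 2.0]])

# keepdim=True keeps the summed dimension, giving column vectors of shape (2, 1).
norms_1 = torch.sum(a**2, dim=1, keepdim=True)
norms_2 = torch.sum(b**2, dim=1, keepdim=True)
print(norms_1.shape, norms_1.flatten().tolist())  # torch.Size([2, 1]) [5.0, 25.0]
print(norms_2.shape, norms_2.flatten().tolist())  # torch.Size([2, 1]) [1.0, 8.0]

# Without keepdim the result is 1-D, so there is no dimension 1 to transpose.
flat = torch.sum(a**2, dim=1)
print(flat.shape)  # torch.Size([2])
```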

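Next come norms_1.expand(n_1, n_2) and norms_2.transpose(0, 1).expand(n_1, n_2), which repeat the column vector norms_1 across n_2 columns and the transposed row vector across n_1 rows:

$$\begin{bmatrix} a_{00}^2 + a_{01}^2 & a_{00}^2 + a_{01}^2 \\ a_{10}^2 + a_{11}^2 & a_{10}^2 + a_{11}^2 \end{bmatrix}, \qquad \begin{bmatrix} b_{00}^2 + b_{01}^2 & b_{10}^2 + b_{11}^2 \\ b_{00}^2 + b_{01}^2 & b_{10}^2 + b_{11}^2 \end{bmatrix}$$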

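Note that norms_2 is transposed before it is expanded, so row i of the first matrix holds the squared norm of a's i-th vector, while column j of the second holds that of b's j-th vector. Writing a_i and b_j for those vectors, their sum gives norms, whose [i, j] entry is $\lVert a_i \rVert^2 + \lVert b_j \rVert^2$:

$$\text{norms} = \begin{bmatrix} a_{00}^2 + a_{01}^2 + b_{00}^2 + b_{01}^2 & a_{00}^2 + a_{01}^2 + b_{10}^2 + b_{11}^2 \\ a_{10}^2 + a_{11}^2 + b_{00}^2 + b_{01}^2 & a_{10}^2 + a_{11}^2 + b_{10}^2 + b_{11}^2 \end{bmatrix}$$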
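The explicit expand calls build an n_1 × n_2 matrix whose [i, j] entry is the squared norm of a's i-th vector plus that of b's j-th vector. A sketch with the same made-up values, comparing that with plain broadcasting:

```python
import torch

# Made-up samples standing in for a and b.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 1.0], [2.0, 2.0]])
n_1, n_2 = a.size(0), b.size(0)

norms_1 = torch.sum(a**2, dim=1, keepdim=True)  # (n_1, 1)
norms_2 = torch.sum(b**2, dim=1, keepdim=True)  # (n_2, 1)

# Version from pdist: expand repeats the column (as a view, no copy)
# across n_2 columns and the transposed row across n_1 rows, then adds.
norms_explicit = norms_1.expand(n_1, n_2) + norms_2.transpose(0, 1).expand(n_1, n_2)

# Broadcasting performs the same expansion implicitly: (n_1, 1) + (1, n_2) -> (n_1, n_2).
norms_broadcast = norms_1 + norms_2.t()

print(norms_explicit.tolist())  # [[6.0, 13.0], [26.0, 33.0]]
print(torch.equal(norms_explicit, norms_broadcast))  # True
```

In current PyTorch the explicit expand calls can be dropped, since broadcasting gives the same result.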

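sample_1.mm(sample_2.t()) is the matrix product of a and the transpose of b, so its [i, j] entry is the dot product of a's i-th vector with b's j-th vector:

$$a b^\top = \begin{bmatrix} a_{00} b_{00} + a_{01} b_{01} & a_{00} b_{10} + a_{01} b_{11} \\ a_{10} b_{00} + a_{11} b_{01} & a_{10} b_{10} + a_{11} b_{11} \end{bmatrix}$$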
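A quick check, on made-up values, that one matrix product yields every pairwise dot product, compared against building the same matrix entry by entry with torch.dot:

```python
import torch

# Made-up samples standing in for a and b.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 1.0], [2.0, 2.0]])

# One matrix product: entry [i, j] = a[i] . b[j].
cross = a.mm(b.t())

# Same matrix built entry by entry for comparison.
by_loop = torch.tensor([[torch.dot(a[i], b[j]).item() for j in range(b.size(0))]
                        for i in range(a.size(0))])

print(cross.tolist())  # [[2.0, 6.0], [4.0, 14.0]]
print(torch.equal(cross, by_loop))  # True
```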

Therefore, after the operation

distances_squared = norms - 2 * sample_1.mm(sample_2.t())

you get

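$$\text{distances\_squared} = \begin{bmatrix} (a_{00} - b_{00})^2 + (a_{01} - b_{01})^2 & (a_{00} - b_{10})^2 + (a_{01} - b_{11})^2 \\ (a_{10} - b_{00})^2 + (a_{11} - b_{01})^2 & (a_{10} - b_{10})^2 + (a_{11} - b_{11})^2 \end{bmatrix}$$

so entry [i, j] is the squared Euclidean distance between a's i-th vector and b's j-th vector.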
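To see that the expansion matches the direct definition, here is a sketch on made-up values that also forms every difference vector explicitly via broadcasting:

```python
import torch

# Made-up samples standing in for a and b.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 1.0], [2.0, 2.0]])

# The pdist route: squared norms plus -2 times the cross term.
norms = torch.sum(a**2, dim=1, keepdim=True) + torch.sum(b**2, dim=1, keepdim=True).t()
distances_squared = norms - 2 * a.mm(b.t())

# The direct route: every difference a[i] - b[j] as an (n_1, n_2, d) tensor.
direct = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=2)

print(distances_squared.tolist())  # [[2.0, 1.0], [18.0, 5.0]]
print(torch.allclose(distances_squared, direct))  # True
```

The direct route needs an n_1 × n_2 × d intermediate, while the expansion needs only n_1 × n_2 matrices plus one matrix multiplication; that is a plausible reason for writing it this way, though the linked code does not say so.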

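Finally, the square root of every element is taken. The code uses torch.sqrt(eps + torch.abs(distances_squared)) rather than a plain square root: the abs maps the tiny negative values that floating-point cancellation can produce back to non-negative numbers, and eps presumably keeps the gradient of the square root finite where a distance is exactly zero, such as on the diagonal of pdist(sample_12, sample_12).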
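As a sanity check, a compact sketch of the norm == 2 branch compared against PyTorch's built-in torch.cdist on random data. The function name pairwise_l2 is hypothetical, and the tolerance is loosened because eps shifts every distance slightly:

```python
import torch

def pairwise_l2(x, y, eps=1e-5):
    """Hypothetical sketch of pdist's norm == 2 branch, using broadcasting."""
    norms = (x ** 2).sum(dim=1, keepdim=True) + (y ** 2).sum(dim=1, keepdim=True).t()
    distances_squared = norms - 2 * x.mm(y.t())
    # abs() folds tiny negative rounding errors back to >= 0; eps keeps sqrt off 0.
    return torch.sqrt(eps + torch.abs(distances_squared))

torch.manual_seed(0)
x = torch.randn(5, 3)
y = torch.randn(4, 3)

ours = pairwise_l2(x, y)
reference = torch.cdist(x, y, p=2.0)
print(ours.shape)  # torch.Size([5, 4])
print(torch.allclose(ours, reference, atol=1e-3))
```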

Upvotes: 13
