Kaushik Roy


Calculate Batch Pairwise Sinkhorn Distance in PyTorch

I have two tensors of the same shape, and I want to calculate the pairwise Sinkhorn distance between them using GeomLoss.

What I have tried:

import torch
import geomloss  # pip install git+https://github.com/jeanfeydy/geomloss

a = torch.rand((8,4))
b = torch.rand((8,4))

geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [n_points, feature_dim]: GeomLoss reads each 2-D tensor as ONE cloud of 8 points
# will return a single scalar value

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))  
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i

However, I would like to compute the pairwise distances, so that the resulting tensor has size [batch, batch] and entry [i, j] is the distance between a[i] and b[j]. To achieve this, I tried to use broadcasting:

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))

But I got this error message:

ValueError: Samples x and y should have the same batchsize.
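A short shape check may help explain this error. The sketch below uses only PyTorch. It relies on the [batch, n_points, feature_dim] convention from the comments above. The remark about what broadcasting would compute is an interpretation of those shapes, not something taken from the GeomLoss documentation.

```python
import torch

a = torch.rand(8, 4)
b = torch.rand(8, 4)

# GeomLoss reads a 3-D input as [batch, n_points, feature_dim].
x = a.unsqueeze(0)  # [1, 8, 4]: ONE cloud holding all 8 points of a
y = b.unsqueeze(1)  # [8, 1, 4]: EIGHT clouds holding one point each
print(x.shape, y.shape)  # torch.Size([1, 8, 4]) torch.Size([8, 1, 4])

# Batch sizes 1 and 8 differ, hence the ValueError. Even if the batch dimension
# were broadcast, each result would compare the whole cloud a with a single b[j],
# giving 8 values rather than the [8, 8] matrix of a[i]<->b[j] distances.
```

So the (i, j) pairs have to be materialized explicitly, either one row at a time or all at once.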


Answers (1)

Ivan


Since the documentation doesn't give examples of how to use the distance's forward function for this case, here's one way to do it. It requires calling the distance function once per row, i.e. batch times.

We will construct the distance matrix row by row. Row i holds the distances a[i]<->b[0], a[i]<->b[1], ..., a[i]<->b[batch-1]. To compute a whole row in one call, we build, for each row i, an (8x4) tensor in which a[i] is repeated 8 times, so that it has the same batch size as b.

This will do:

a_i = torch.stack(8*[a[i]], dim=0)
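The same repeated tensor can also be built with standard PyTorch methods. This sketch compares the answer's torch.stack version with expand, which returns a view without copying, and repeat, which makes a copy. The row index i = 3 is arbitrary. The original answer does not say whether GeomLoss accepts a non-contiguous expanded view directly. If in doubt, call .contiguous() on it or use repeat instead.

```python
import torch

a = torch.rand(8, 4)
batch = a.shape[0]
i = 3  # any row index

a_i_stack = torch.stack(batch * [a[i]], dim=0)  # answer's version: a fresh [8, 4] copy
a_i_expand = a[i].expand(batch, -1)              # same values as an [8, 4] view, no copy
a_i_repeat = a[i].repeat(batch, 1)               # same values, materialized copy

print(a_i_expand.shape)  # torch.Size([8, 4])
print(torch.equal(a_i_stack, a_i_expand), torch.equal(a_i_stack, a_i_repeat))  # True True
```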

Then we calculate the distance between a[i] and each element of b:

dist(a_i.unsqueeze(1), b.unsqueeze(1))  # size [batch]: entry j is the distance a[i]<->b[j]

Once all batch rows are computed, we stack them into the final [batch, batch] tensor.


Here's the complete code:

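import torch
import geomloss

batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
# Row i: distances between a[i] (repeated batch times) and every b[j]
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)  # size [batch, batch], D[i, j] = distance a[i]<->b[j]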
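If memory allows, the loop can be replaced by a single call. You enumerate all batch*batch pairs up front and reshape the result. This is a minimal sketch, not part of the original answer. The helper pairwise_samples_loss and the stand-in toy_loss are hypothetical names. toy_loss only mimics SamplesLoss's batched calling convention ([B, N, D], [B, M, D]) -> [B] so that the sketch runs without GeomLoss. It is NOT a Sinkhorn distance.

```python
import torch

def pairwise_samples_loss(loss_fn, a, b):
    """D[i, j] = loss_fn between a[i] and b[j], each row treated as a one-point cloud.

    loss_fn is assumed to follow the batched convention
    (x: [B, N, D], y: [B, M, D]) -> [B], like geomloss.SamplesLoss('sinkhorn').
    """
    n, m = a.shape[0], b.shape[0]
    # Enumerate all (i, j) pairs: flat index k = i * m + j  ->  (a[i], b[j])
    x = a.repeat_interleave(m, dim=0).unsqueeze(1)  # [n*m, 1, d]: a[0] m times, a[1] m times, ...
    y = b.repeat(n, 1).unsqueeze(1)                 # [n*m, 1, d]: b[0..m-1] tiled n times
    return loss_fn(x, y).reshape(n, m)              # one call, then back to [n, m]

def toy_loss(x, y):
    # Hypothetical stand-in with the same calling convention; NOT a Sinkhorn distance.
    # Half the squared distance between the two clouds' mean points, per batch entry.
    return 0.5 * ((x.mean(dim=1) - y.mean(dim=1)) ** 2).sum(dim=-1)

a = torch.rand(8, 4)
b = torch.rand(8, 4)
D = pairwise_samples_loss(toy_loss, a, b)
print(D.shape)  # torch.Size([8, 8])
# With GeomLoss installed: D = pairwise_samples_loss(geomloss.SamplesLoss('sinkhorn'), a, b)
```

The trade-off is memory. This version solves batch*batch small problems in one call, whereas the loop above keeps each call at batch problems. Under GeomLoss's default p=2 ground cost, C(x, y) = |x - y|^2 / 2, a one-point cloud allows only one transport plan. So one would expect each Sinkhorn entry to come out close to 0.5 * ||a[i] - b[j]||^2, and 0.5 * torch.cdist(a, b) ** 2 may serve as a sanity check.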

