Reputation: 1685
I have two tensors of the same shape, and I want to calculate the pairwise Sinkhorn distance between them using GeomLoss.
What I have tried:
import torch
import geomloss # pip install git+https://github.com/jeanfeydy/geomloss
a = torch.rand((8,4))
b = torch.rand((8,4))
geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ 2-D input [n_points, feature_dim]: read as ONE point cloud of 8 points
# so this returns a single scalar value
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i
However, I would like to compute the pairwise distances, so that the resulting tensor has size [batch, batch]. To achieve this, I tried to use broadcasting:
geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))
But I got this error message:
ValueError: Samples x and y should have the same batchsize.
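The shapes explain the error. SamplesLoss reads a 3-D tensor as [batch, n_points, feature_dim], and the error message shows that it compares the two batch sizes directly rather than broadcasting them. A quick shape check in plain PyTorch (no GeomLoss needed):

```python
import torch

a = torch.rand((8, 4))
b = torch.rand((8, 4))

x = a.unsqueeze(0)  # [1, 8, 4]: ONE cloud containing all 8 points of a
y = b.unsqueeze(1)  # [8, 1, 4]: 8 clouds, each holding a single point of b

# Batch sizes 1 and 8 differ, which is what triggers the ValueError.
print(x.shape, y.shape)

# Even if the batch dimension were broadcast, both inputs would become
# [8, 8, 4], i.e. 8 clouds of 8 points each, not the 8 x 8 grid of
# single-point pairs that a [batch, batch] distance matrix needs.
print(torch.broadcast_shapes(x.shape, y.shape))
```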
Upvotes: 0
Views: 975
Reputation: 40728
Since the documentation doesn't give examples of how to use the distance's forward function, here's a way to do it that requires calling the distance function batch times.
We will construct the distance matrix row by row. Row i holds the distances a[i]<->b[0], a[i]<->b[1], through to a[i]<->b[batch-1]. To do so, for each row i we construct an 8x4 tensor (batch x feature_dim) in which a[i] is repeated once for every element of b.
This will do:
a_i = torch.stack(8*[a[i]], dim=0)
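torch.stack copies a[i] batch times. For comparison, a[i].repeat(batch, 1) builds the same [batch, feature_dim] tensor in one call, and a[i].expand(batch, -1) gives the same values as a view without copying any data. This small check assumes the same 8x4 shapes as the question:

```python
import torch

batch, feature_dim = 8, 4
a = torch.rand((batch, feature_dim))
i = 3  # any row index

a_i_stack = torch.stack(batch * [a[i]], dim=0)  # copies a[i] batch times
a_i_repeat = a[i].repeat(batch, 1)              # same values, one call
a_i_expand = a[i].expand(batch, -1)             # view with stride 0, no copy

print(a_i_stack.shape)  # torch.Size([8, 4])
print(torch.equal(a_i_stack, a_i_repeat), torch.equal(a_i_stack, a_i_expand))
```

The expanded view is not contiguous; if a downstream backend insists on contiguous memory, calling .contiguous() on it is a safe fallback.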
Then we calculate the distance between a[i] and each element of b:
dist(a_i.unsqueeze(1), b.unsqueeze(1))
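Wrapped as a function, computing one row could look like the sketch below. row_distances is a hypothetical helper name, and loss_fn stands for any callable that maps two [batch, n_points, feature_dim] tensors to a [batch] tensor, such as geomloss.SamplesLoss('sinkhorn'). So that the sketch runs without GeomLoss installed, the demo passes a stand-in loss (half the squared Euclidean distance between single-point clouds); that stand-in is an assumption for illustration only, not GeomLoss's implementation.

```python
import torch

def row_distances(loss_fn, a, b, i):
    """Row i of the distance matrix: loss_fn(a[i], b[j]) for every j.

    loss_fn is assumed to map two [batch, n_points, feature_dim] tensors
    to a [batch] tensor, e.g. geomloss.SamplesLoss('sinkhorn').
    """
    batch = b.shape[0]
    a_i = torch.stack(batch * [a[i]], dim=0)          # [batch, feature_dim]
    return loss_fn(a_i.unsqueeze(1), b.unsqueeze(1))  # [batch]

# Hypothetical stand-in loss for the demo: 0.5 * ||x - y||^2 per cloud,
# meaningful only for single-point clouds (n_points == 1).
def stand_in_loss(x, y):
    return 0.5 * ((x - y) ** 2).sum(dim=(1, 2))

a = torch.rand((8, 4))
b = torch.rand((8, 4))
row = row_distances(stand_in_loss, a, b, i=2)
print(row.shape)  # torch.Size([8])
```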
With batch rows in total, we can stack them into the final [batch, batch] tensor with torch.stack.
Here's the complete code:
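import torch
import geomloss

batch = a.shape[0]  # a, b: [batch, feature_dim], as in the question
dist = geomloss.SamplesLoss('sinkhorn')
# Row i: distances between a[i] (repeated batch times) and every b[j]
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)  # D[i, j] = dist(a[i], b[j]), shape [batch, batch]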
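If calling the loss batch times is too slow, the same matrix can in principle be built with a single call: materialize all batch x batch pairs as single-point clouds, evaluate them as one large batch, and reshape the result. This is a sketch, not something the GeomLoss documentation shows. pairwise_distances is a hypothetical helper, loss_fn is assumed to behave like geomloss.SamplesLoss on [batch, n_points, feature_dim] inputs, and the demo again uses a stand-in loss so it runs without GeomLoss.

```python
import torch

def pairwise_distances(loss_fn, a, b):
    """D[i, j] = loss_fn(cloud a[i], cloud b[j]), computed in one batched call.

    a: [batch_a, feature_dim], b: [batch_b, feature_dim]; each row is treated
    as a single-point cloud, as in the loop above.
    """
    batch_a, dim = a.shape
    batch_b = b.shape[0]
    # x[i, j] = a[i], y[i, j] = b[j], flattened to batch_a*batch_b clouds of 1 point
    x = a[:, None, :].expand(batch_a, batch_b, dim).reshape(-1, 1, dim)
    y = b[None, :, :].expand(batch_a, batch_b, dim).reshape(-1, 1, dim)
    return loss_fn(x, y).reshape(batch_a, batch_b)

# Hypothetical stand-in for geomloss.SamplesLoss('sinkhorn'), demo only
def stand_in_loss(x, y):
    return 0.5 * ((x - y) ** 2).sum(dim=(1, 2))

a = torch.rand((8, 4))
b = torch.rand((8, 4))
batch = a.shape[0]
D = pairwise_distances(stand_in_loss, a, b)

# Row-by-row loop from the answer, run with the same stand-in loss
loop = torch.stack([stand_in_loss(torch.stack(batch * [a[i]]).unsqueeze(1), b.unsqueeze(1))
                    for i in range(batch)])
print(D.shape, torch.allclose(D, loop))
```

With GeomLoss installed, geomloss.SamplesLoss('sinkhorn') would be passed as loss_fn. The trade-off is memory, since all batch² pairs are held at once; for large batches the row-by-row loop may be preferable, and it is worth comparing both versions on a small input before relying on the batched one.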
Upvotes: 1