Reputation: 731
I have a 2x2 reference tensor and a batch of candidate 2x2 tensors. I would like to find the candidate tensor closest to the reference tensor by the summed squared difference (i.e. the squared Euclidean distance) over identically indexed elements, ignoring the batch index.
For example:
import torch
ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
I would like to find the index of the 2x2 tensor in candidates that minimizes:
(ref[0][0] - candidates[index][0][0])**2 +
(ref[0][1] - candidates[index][0][1])**2 +
(ref[1][0] - candidates[index][1][0])**2 +
(ref[1][1] - candidates[index][1][1])**2
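Written compactly, the objective above is the squared Euclidean distance between the flattened tensors. The symbols $\hat{\imath}$ and $\operatorname{vec}(\cdot)$ ("flatten to a vector") are notation introduced here for convenience:

$$
\hat{\imath} = \arg\min_{i} \sum_{j=0}^{1} \sum_{k=0}^{1} \bigl(\mathrm{ref}_{jk} - \mathrm{candidates}_{i,jk}\bigr)^2
= \arg\min_{i} \bigl\lVert \operatorname{vec}(\mathrm{ref}) - \operatorname{vec}(\mathrm{candidates}_i) \bigr\rVert_2^2
$$

Because the square root is monotonic, the same index also minimizes the unsquared Euclidean distance. For a reference of shape (b, c, ..., z), the double sum becomes a sum over all element positions.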
Ideally, the solution would work for a reference tensor of arbitrary shape (b, c, d, ..., z) and an arbitrary batch_size of candidate tensors with matching dimensions, i.e. candidates of shape (batch_size, b, c, d, ..., z).
Upvotes: 0
Views: 90
Reputation: 1680
Elaborating on @ndrwnaguib's answer, it should rather be:
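# Euclidean distance from the flattened ref to each flattened candidate: shape (1, N)
dist = torch.cdist(ref.float().flatten().unsqueeze(0), candidates.flatten(start_dim=1))
print(torch.square(dist))  # squared distances, i.e. the summed squared differences
torch.argmin(dist)  # flat index; equals the candidate index because dist has a single row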
tensor([[23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489]])
tensor(9)
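Since cdist works on 2-D (or batched 3-D) inputs, the same idea extends to the arbitrary-shape case from the question by flattening everything after the batch dimension. A minimal sketch, assuming candidates has shape (batch_size, *ref.shape); the function name `closest_by_cdist` is hypothetical:

```python
import torch


def closest_by_cdist(ref: torch.Tensor, candidates: torch.Tensor) -> int:
    """Return the batch index of the candidate nearest to ref (Euclidean).

    Assumes candidates has shape (batch_size, *ref.shape).
    """
    # One row holding every element of ref: shape (1, numel)
    ref_row = ref.to(candidates.dtype).reshape(1, -1)
    # One row per candidate: shape (batch_size, numel)
    cand_rows = candidates.reshape(candidates.shape[0], -1)
    # Euclidean distance from ref to each candidate: shape (1, batch_size)
    dist = torch.cdist(ref_row, cand_rows)
    # Minimizing the distance also minimizes its square
    return int(torch.argmin(dist[0]))


# Usage with the shapes from the question
ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
index = closest_by_cdist(ref, candidates)
```

The ref.to(candidates.dtype) cast plays the same role as ref.float() above: cdist needs floating-point inputs, and the integer ref from torch.as_tensor would otherwise not match.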
Other options worth noting:
torch.square(ref.float() - candidates).sum(dim=(1, 2))
tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489])
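The hard-coded dim=(1, 2) only covers 2-D references. A sketch of the arbitrary-shape version, assuming ref has at least one dimension and candidates has shape (batch_size, *ref.shape); `sq_dist_to_ref` is a hypothetical name:

```python
import torch


def sq_dist_to_ref(ref: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Sum of squared differences between ref and each candidate.

    Assumes ref has at least one dimension and candidates has
    shape (batch_size, *ref.shape).
    """
    # Broadcasting subtracts ref from every candidate at once
    diff = ref.to(candidates.dtype) - candidates
    # Reduce over every dimension except the batch dimension (dim 0)
    non_batch_dims = tuple(range(1, candidates.dim()))
    return diff.square().sum(dim=non_batch_dims)  # shape (batch_size,)


ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(10, 2, 2)
dists = sq_dist_to_ref(ref, candidates)
index = int(torch.argmin(dists))
```

Unlike cdist, this never takes a square root, so the returned values are directly the summed squared differences from the question.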
diff = ref.float() - candidates
torch.einsum("abc,abc->a", diff, diff)
tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489])
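The einsum subscripts "abc" are likewise tied to a 2-D reference. One way to make them shape-agnostic, sketched here under the assumption that candidates has shape (batch_size, *ref.shape) (the helper name `sq_dist_einsum` is hypothetical), is to flatten first so a fixed two-index pattern always applies:

```python
import torch


def sq_dist_einsum(ref: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Same quantity as the sum-of-squares version, computed via einsum.

    Assumes ref has at least one dimension and candidates has
    shape (batch_size, *ref.shape).
    """
    # Flatten everything after the batch dimension: shape (batch_size, numel)
    diff = (ref.to(candidates.dtype) - candidates).flatten(start_dim=1)
    # Row-wise dot product of diff with itself = sum of squared differences
    return torch.einsum("ab,ab->a", diff, diff)


ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(10, 2, 2)
index = int(torch.argmin(sq_dist_einsum(ref, candidates)))
```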
Upvotes: 1