Reputation: 432
I have a 2D tensor of integer (x, y) coordinates on a grid, and I would like to check it for any occurrences of a specific coordinate pair (x, y).
A pseudo-code example:
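import torch

positions = torch.arange(20).repeat(2).view(-1, 2)  # 20 rows of (x, y) pairs
xy_dst1 = torch.tensor((5, 7))
xy_dst2 = torch.tensor((4, 5))
# Intended row-wise results (plain == actually compares element by element):
positions == xy_dst1  # should find no matching row
positions == xy_dst2  # should find rows 2 and 12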
My only solution so far is to convert the tensors to lists or tuples and iterate over them, but with the back-and-forth conversions and the Python-level loop, that can't be a very good solution. Does anyone know of a better approach that stays within tensor operations?
Reputation: 685
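Try comparing every row at once via broadcasting, then keep only the rows where both coordinates match: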
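def check(positions, xy):
    # Broadcast the (1, 2) target against the (N, 2) positions,
    # require both x and y to match per row, and return the matching row indices
    return (positions == xy.view(1, 2)).all(dim=1).nonzero()

print(check(positions, xy_dst1))
# Output: tensor([], size=(0, 1), dtype=torch.int64)
print(check(positions, xy_dst2))
# Output:
# tensor([[ 2],
#         [12]])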
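A few variations on the same idea may be handy, sketched below under the assumption that `positions` is an `(N, 2)` integer tensor as in the question. The helper names `match_mask`, `match_indices`, `contains` and `match_many` are hypothetical, chosen for illustration. `torch.nonzero(..., as_tuple=True)[0]` yields a flat 1-D index tensor instead of the `(K, 1)` shape above, `.any()` gives a plain yes/no answer, and adding a second broadcast axis lets you look up several target pairs in one pass.

```python
import torch

# Hypothetical helper names, for illustration only.

def match_mask(positions: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    # (N,) boolean mask: True where a row's x AND y both equal the target pair
    return (positions == xy.view(1, 2)).all(dim=1)

def match_indices(positions: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    # Flat 1-D tensor of matching row indices (no trailing size-1 dimension)
    return torch.nonzero(match_mask(positions, xy), as_tuple=True)[0]

def contains(positions: torch.Tensor, xy: torch.Tensor) -> bool:
    # True if the pair occurs anywhere in positions
    return bool(match_mask(positions, xy).any())

def match_many(positions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # positions: (N, 2), targets: (M, 2) -> (N, M) boolean matrix,
    # where entry [i, j] is True when row i of positions equals target j
    return (positions[:, None, :] == targets[None, :, :]).all(dim=2)

positions = torch.arange(20).repeat(2).view(-1, 2)
xy_dst1 = torch.tensor((5, 7))
xy_dst2 = torch.tensor((4, 5))

print(contains(positions, xy_dst1))                # False
print(match_indices(positions, xy_dst2).tolist())  # [2, 12]

targets = torch.stack([xy_dst1, xy_dst2])  # shape (2, 2)
# Each row of the result is a (position row, target index) pair
print(match_many(positions, targets).nonzero())
```

The multi-target version builds an `(N, M, 2)` comparison tensor, so for very large `N * M` memory use grows accordingly; for a single pair, the original `check` is enough.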