Tue

Reputation: 432

pairwise/rowwise comparison of pytorch tensor

I have a 2D tensor representing integer coordinates on a grid, and I would like to check it for any occurrences of a specific coordinate (x, y).

A pseudo-code example:

import torch

positions = torch.arange(20).repeat(2).view(-1, 2)  # 20 rows of (x, y)
xy_dst1 = torch.tensor((5, 7))
xy_dst2 = torch.tensor((4, 5))
positions == xy_dst1  # should give no matching rows
positions == xy_dst2  # should give row indices 2 and 12

My only solution so far is to convert the tensors to lists or tuples and iterate over them, but with the conversions back and forth and the Python-level loop, that can't be a very good solution. Does anyone know of a better solution that stays in the tensor framework?

Upvotes: 1

Views: 415

Answers (1)

kmkurn

Reputation: 685

Try comparing with broadcasting and keeping the rows where both coordinates match:

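def check(positions, xy):
    # Broadcast xy against every row, keep rows where both coordinates
    # match, and return their indices as an (M, 1) tensor.
    return (positions == xy.view(1, 2)).all(dim=1).nonzero()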

print(check(positions, xy_dst1))
# Output: tensor([], size=(0, 1), dtype=torch.int64)

print(check(positions, xy_dst2))
# Output:
# tensor([[ 2],
#         [12]])
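If you would rather get a flat list of row indices than an (M, 1) tensor, here is a minimal sketch of the same idea split into its steps. The name `find_rows` is made up for illustration, and it assumes a PyTorch version where `nonzero` accepts `as_tuple=True`.

```python
import torch


def find_rows(positions: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    """Return a 1-D tensor of row indices i where positions[i] equals xy.

    Hypothetical helper, not part of PyTorch.
    """
    # xy has shape (2,), so it broadcasts against positions of shape (N, 2).
    elementwise = positions == xy        # (N, 2) bool, one flag per coordinate
    row_match = elementwise.all(dim=1)   # (N,) bool, True where x and y both match
    # as_tuple=True returns one 1-D index tensor per dimension; take the first.
    return row_match.nonzero(as_tuple=True)[0]


positions = torch.arange(20).repeat(2).view(-1, 2)
hits = find_rows(positions, torch.tensor((4, 5)))
misses = find_rows(positions, torch.tensor((5, 7)))

print(hits.tolist())    # [2, 12]
print(misses.tolist())  # []

# If only a yes/no answer is needed, reduce the row mask with any():
print(bool((positions == torch.tensor((4, 5))).all(dim=1).any()))  # True
```

The `all(dim=1)` step is what turns the elementwise comparison into a rowwise one. Without it, a row such as (4, 7) would be counted as a partial match for (4, 5).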

Upvotes: 1
