Scoops

Reputation: 61

Check equality of any top k entries in rows of tensor A against argmax in rows of tensor B

New to tensors/PyTorch.

I have two 2d tensors, A and B.

A contains floats that represent the probability assigned to each index. Each row of B is a one-hot binary vector with a 1 at the correct index.

A
tensor([[0.1, 0.4, 0.5],
        [0.5, 0.4, 0.1],
        [0.4, 0.5, 0.1]])

B
tensor([[0, 0, 1],
        [0, 1, 0],
        [0, 0, 1]])

I would like to find the rows (and from them, the number of rows) in which the index of any of the top-k values of A matches the one-hot index in B. In this case, k=2.
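To make the goal concrete, here is a deliberately slow, loop-based sketch of the intended check. It assumes every row of B contains exactly one 1, and the names k and matching_rows are only for illustration. It is meant as a reference to compare vectorized versions against, not as an efficient solution.

```python
import torch

# Example tensors from the question
A = torch.tensor([[0.1, 0.4, 0.5],
                  [0.5, 0.4, 0.1],
                  [0.4, 0.5, 0.1]])
B = torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [0, 0, 1]])
k = 2

matching_rows = []
for i in range(A.shape[0]):
    # Indices of the k largest probabilities in row i
    top_k = torch.topk(A[i], k).indices.tolist()
    # Position of the single 1 in the one-hot row i (assumes B is one-hot)
    target = int(B[i].argmax())
    if target in top_k:
        matching_rows.append(i)

print(matching_rows)       # [0, 1]
print(len(matching_rows))  # 2
```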

My attempt:

tops = torch.topk(A, 2, dim=1)
top_idx = tops.indices
top_2_matches = torch.where((torch.any(top_idx, 1) == B.argmax(dim=1)))

Done correctly, this example should return tensor([0, 1]), since the first two rows have top-2 matches, but I get (tensor([1]),) instead.

Not sure where I'm going wrong here. Thanks for any help!

Upvotes: 1

Views: 108

Answers (1)

A. Maman

Reputation: 980

Try this:

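import torch

# Indices of the k=2 largest probabilities in each row of A, shape (N, 2)
top_idx = torch.topk(A, 2, dim=1).indices
# unsqueeze turns the correct indices into shape (N, 1) so they broadcast against
# (N, 2); any(dim=1) then marks rows where at least one top-k index matches
row_indicator = (top_idx == B.argmax(dim=1).unsqueeze(dim=1)).any(dim=1)
# Keep the row numbers whose indicator is True
top_2_matches = torch.arange(len(row_indicator))[row_indicator]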
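It may help to see why the original attempt returns (tensor([1]),). In the sketch below (the names targets, collapsed and mask are illustrative), torch.any(top_idx, 1) only asks whether each row contains a nonzero index, so every row becomes True and the actual index values are discarded before the comparison ever happens.

```python
import torch

# Example tensors from the question
A = torch.tensor([[0.1, 0.4, 0.5],
                  [0.5, 0.4, 0.1],
                  [0.4, 0.5, 0.1]])
B = torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [0, 0, 1]])

top_idx = torch.topk(A, 2, dim=1).indices  # tensor([[2, 1], [0, 1], [1, 0]])
targets = B.argmax(dim=1)                  # tensor([2, 1, 2])

# torch.any(top_idx, 1) only asks "does this row contain a nonzero index?",
# so every row collapses to True and the index values themselves are lost.
collapsed = torch.any(top_idx, 1)          # tensor([True, True, True])

# Comparing booleans with integer targets promotes True to 1, so only the
# row whose correct index happens to be 1 compares equal.
mask = collapsed == targets                # tensor([False, True, False])

print(torch.where(mask))                   # (tensor([1]),)
```

The version above avoids this by comparing the indices themselves against the targets first and only reducing with any(dim=1) afterwards.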

For example:

>>> import torch
>>> A = torch.tensor([[0.1, 0.4, 0.5],
...                   [0.5, 0.4, 0.1],
...                   [0.4, 0.5, 0.1]])
>>> B = torch.tensor([[0, 0, 1],
...                   [0, 1, 0],
...                   [0, 0, 1]])
>>> tops = torch.topk(A, 2, dim=1)
>>> tops
torch.return_types.topk(
values=tensor([[0.5000, 0.4000],
               [0.5000, 0.4000],
               [0.5000, 0.4000]]),
indices=tensor([[2, 1],
                [0, 1],
                [1, 0]]))
>>> top_idx = tops.indices
>>> top_idx
tensor([[2, 1],
        [0, 1],
        [1, 0]])
>>> index_indicator = top_idx == B.argmax(dim=1).unsqueeze(dim=1)
>>> index_indicator
tensor([[ True, False],
        [False,  True],
        [False, False]])
>>> row_indicator = index_indicator.any(dim=1)
>>> row_indicator
tensor([ True,  True, False])
>>> top_2_matches = torch.arange(len(row_indicator))[row_indicator]
>>> top_2_matches
tensor([0, 1])
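If the same check is needed for other values of k, or you want the number of matching rows rather than their indices, the steps above can be wrapped in a small helper. This is a sketch: the function names topk_match_rows and topk_match_rows_gather are hypothetical, and both assume A and B have the same shape (N, C) with one-hot rows in B. The first function uses argmax(..., keepdim=True) in place of unsqueeze and nonzero in place of arange indexing, which give the same result. The second is an alternative formulation, not part of the answer above, that reads B with torch.gather at the top-k positions of A instead of computing argmax.

```python
import torch


def topk_match_rows(A: torch.Tensor, B: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: indices of rows whose one-hot target in B is in A's top-k."""
    top_idx = torch.topk(A, k, dim=1).indices        # (N, k)
    targets = B.argmax(dim=1, keepdim=True)          # (N, 1), broadcasts against (N, k)
    row_indicator = (top_idx == targets).any(dim=1)  # (N,) bool
    # Positions of True entries, equivalent to torch.arange(N)[row_indicator]
    return row_indicator.nonzero(as_tuple=True)[0]


def topk_match_rows_gather(A: torch.Tensor, B: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical alternative: look up B's values at A's top-k positions."""
    top_idx = torch.topk(A, k, dim=1).indices        # (N, k), int64 as gather requires
    hits = torch.gather(B, 1, top_idx)               # B[i, top_idx[i, j]]
    return (hits == 1).any(dim=1).nonzero(as_tuple=True)[0]


# Example tensors from the question
A = torch.tensor([[0.1, 0.4, 0.5],
                  [0.5, 0.4, 0.1],
                  [0.4, 0.5, 0.1]])
B = torch.tensor([[0, 0, 1],
                  [0, 1, 0],
                  [0, 0, 1]])

rows = topk_match_rows(A, B, k=2)
print(rows)                               # tensor([0, 1])
print(rows.numel())                       # 2 matching rows
print(topk_match_rows_gather(A, B, k=2))  # tensor([0, 1])
print(topk_match_rows(A, B, k=1))         # tensor([0])
```

Counting with rows.numel() (or row_indicator.sum()) gives the number of matching rows directly.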

Upvotes: 1
