Mitiku

Reputation: 5412

Top K indices of a multi-dimensional tensor

I have a 2D tensor and I want to get the indices of the top k values. I know about PyTorch's topk function. The problem with it is that it computes the top k values along a single dimension, whereas I want the top k values over both dimensions.

For example for the following tensor

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

PyTorch's topk function will give me the following:

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

But I want to get the following

tensor([[0, 1],
        [2, 0],
        [3, 1]])

These are the indices of the value 9 (the top 3 values) in the 2D tensor.

Is there any way to achieve this using PyTorch?

Upvotes: 5

Views: 8322

Answers (4)

rayryeng

Reputation: 104503

As of PyTorch 2.2, torch.unravel_index is part of the library, so the conversion to NumPy referenced in @mujjiga's answer is no longer required. To borrow from that answer and perform this operation entirely in PyTorch:

v, i = torch.topk(a.flatten(), 3)
indices = torch.column_stack(torch.unravel_index(i, a.shape))
print(indices)

We get:

tensor([[2, 0],
        [0, 1],
        [3, 1]])

Note that when multiple locations share the same value, torch.topk makes no guarantee about their relative order (this has been raised as an issue on the official PyTorch GitHub project), so the index order here differs from the one in the previous answer. As long as you are fine with tied values appearing in an arbitrary order, this should work.

Upvotes: 0

mujjiga

Reputation: 16876

import numpy as np
import torch

v, i = torch.topk(a.flatten(), 3)
print(np.array(np.unravel_index(i.numpy(), a.shape)).T)

Output:

[[3 1]
 [2 0]
 [0 1]]
  1. Flatten and find top k
  2. Convert 1D indices to 2D using unravel_index
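If you'd rather stay entirely in PyTorch and skip the NumPy round-trip, the same two steps can be sketched with plain integer arithmetic for the 2D case (row = flat index // width, col = flat index % width):

```python
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

v, i = torch.topk(a.flatten(), 3)   # step 1: flatten and take top k
rows = i // a.shape[1]              # step 2: recover the 2D indices
cols = i % a.shape[1]
indices = torch.stack((rows, cols), dim=1)
```

Since the order of tied values may differ between PyTorch versions, only the set of index pairs (not their order) is guaranteed.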

Upvotes: 13

dpetrini

Reputation: 1239

You can use some vector operations to filter according to your needs, in this case without using topk.

print(a)
tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

values, indices = torch.max(a, 1)    # row-wise max values and their column indices
temp = torch.zeros_like(values)      # temporary mask
temp[values == 9] = 1                # mark rows whose max is 9 (the wished value)
seq = torch.arange(values.shape[0])  # helper sequence of row numbers
new_seq = seq[temp > 0]              # keep rows whose max is 9
new_temp = indices[new_seq]          # column indices for those rows
final = torch.stack([new_seq, new_temp], dim=1)  # stack both to get (row, col) pairs

print(final)
tensor([[0, 1],
        [2, 0],
        [3, 1]])
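When the target value is known up front, as in this filtering approach, the same result can be read off directly with torch.nonzero, which returns the multi-dimensional indices of every entry matching a condition, in row-major order:

```python
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

print(torch.nonzero(a == 9))
# tensor([[0, 1],
#         [2, 0],
#         [3, 1]])
```

Unlike topk, this finds all occurrences of a single value, so it only coincides with a top-k query when the top k entries are all equal (as they are here).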

Upvotes: 0

Artem Sobolev

Reputation: 6069

You can flatten the original tensor, apply topk, and then convert the resulting scalar indices back to multidimensional indices with something like the following:

import numpy as np

def descalarization(idx, shape):
    res = []
    N = np.prod(shape)
    for n in shape:
        N //= n
        res.append(idx // N)
        idx %= N
    return tuple(res)

Example:

torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns 
# tensor([[3, 1],
#         [2, 0],
#         [0, 1],
#         [3, 4],
#         [2, 4]])

Upvotes: 1
