Reputation: 5412
I have a 2D tensor and I want to get the indices of the top k values. I know about PyTorch's topk function, but it computes the top k values along a single dimension; I want the top k values taken over both dimensions together.
For example, for the following tensor:
a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])
PyTorch's topk function gives me the following:
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
#         [0, 2, 1],
#         [0, 1, 4],
#         [1, 4, 3],
#         [1, 0, 4]])
But I want to get the following:
tensor([[0, 1],
        [2, 0],
        [3, 1]])
These are the indices of the value 9 in the 2D tensor.
Is there any approach to achieve this using PyTorch?
Upvotes: 5
Views: 8322
Reputation: 104503
As of PyTorch 2.2, torch.unravel_index is part of the library, so the conversion to NumPy referenced in @mujjiga's answer is no longer required. Borrowing from that answer, the whole operation can be done in PyTorch:
v, i = torch.topk(a.flatten(), 3)
indices = torch.column_stack(torch.unravel_index(i, a.shape))
print(indices)
We get:
tensor([[2, 0],
        [0, 1],
        [3, 1]])
Note that because multiple locations share the same maximum value and torch.topk is not sort-stable (this has been raised as an issue on the official PyTorch GitHub project), the index order from torch.topk seen in @mujjiga's answer is not the same as what appears here. As long as you are fine with the tied values being out of order, this should work.
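If you do need a deterministic order for ties, one workaround (a sketch of mine, not part of the answer above) is to replace topk with a stable descending sort of the flattened tensor and take the first k indices, so that tied values keep their row-major order:
import torch

# `a` is the 2D tensor from the question.
k = 3
# stable=True keeps equal values in their original (row-major) order.
order = torch.sort(a.flatten(), descending=True, stable=True).indices[:k]
indices = torch.column_stack(torch.unravel_index(order, a.shape))
print(indices)
# tensor([[0, 1],
#         [2, 0],
#         [3, 1]])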
Upvotes: 0
Reputation: 16876
import numpy as np

v, i = torch.topk(a.flatten(), 3)
print(np.array(np.unravel_index(i.numpy(), a.shape)).T)
Output:
[[3 1]
 [2 0]
 [0 1]]
See numpy.unravel_index for how the flat indices are converted back to 2D coordinates.
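For completeness, here is a fully self-contained version of the same idea (my sketch, not part of the answer above), with the NumPy import included and the result converted back to a torch tensor; the .cpu() call only matters if a lives on the GPU, since NumPy cannot read CUDA memory:
import numpy as np
import torch

a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

v, i = torch.topk(a.flatten(), 3)
# np.unravel_index works on CPU arrays, so move the indices off the GPU if needed.
rows_cols = np.array(np.unravel_index(i.cpu().numpy(), a.shape)).T
indices = torch.as_tensor(rows_cols)  # back to a torch tensor of (row, col) pairs
print(indices)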
Upvotes: 13
Reputation: 1239
You can use a few vectorized operations to filter according to your needs, without using topk in this case.
print(a)
tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])
values, indices = torch.max(a, 1)    # get max values, indices
temp = torch.zeros_like(values)      # temporary
temp[values == 9] = 1                # fill temp where values are 9 (wished value)
seq = torch.arange(values.shape[0])  # create a helper sequence
new_seq = seq[temp > 0]              # filter sequence where values are 9
new_temp = indices[new_seq]          # filter indices with sequence where values are 9
final = torch.stack([new_seq, new_temp], dim=1)  # stack both to get result
print(final)
tensor([[0, 1],
        [2, 0],
        [3, 1]])
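If the goal is simply "all positions equal to a given value", a shorter route (my addition, not part of the answer above) is a boolean mask with torch.nonzero; unlike the row-wise max above, it also catches multiple matches within the same row:
# `a` is the 2D tensor from the question.
positions = torch.nonzero(a == 9)  # every (row, col) where the value is 9
print(positions)
# tensor([[0, 1],
#         [2, 0],
#         [3, 1]])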
Upvotes: 0
Reputation: 6069
You can flatten the original tensor, apply topk, and then convert the resulting scalar indices back to multidimensional indices with something like the following:
import numpy as np

def descalarization(idx, shape):
    res = []
    N = np.prod(shape)
    for n in shape:
        N //= n
        res.append(idx // N)
        idx %= N
    return tuple(res)
Example:
torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns
# tensor([[3, 1],
#         [2, 0],
#         [0, 1],
#         [3, 4],
#         [2, 4]])
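As a quick sanity check (my addition; it assumes a from the question, the descalarization helper above, and PyTorch >= 2.2 for torch.unravel_index), the manual conversion agrees with the built-in one on the same flat indices:
# Uses `a` from the question and `descalarization` defined above.
flat = torch.topk(a.flatten(), 5).indices
manual = torch.as_tensor(np.array([descalarization(int(k), tuple(a.shape)) for k in flat]))
builtin = torch.stack(torch.unravel_index(flat, a.shape), dim=1)
print(torch.equal(manual, builtin))  # True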
Upvotes: 1