dyno8426

Reputation: 1249

Extracting the top-k value-indices from a 1-D Tensor

Given a 1-D tensor in Torch (torch.Tensor) containing comparable values (say, floating point), how can we extract the indices of the top-k values in that tensor?

Apart from the brute-force method, I am looking for an API call that Torch/Lua provides which can perform this task efficiently.

Upvotes: 12

Views: 17302

Answers (3)

ChaosPredictor

Reputation: 4071

In PyTorch, you can use the topk function.

For example:

import torch

t = torch.tensor([5.7, 1.4, 9.5, 1.6, 6.1, 4.3])

values,indices = t.topk(2)

print(values)
print(indices)

The result (PyTorch indices are 0-based):

tensor([9.5000, 6.1000])
tensor([2, 4])
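To make the semantics concrete without depending on PyTorch, here is a minimal pure-Python sketch of what t.topk(k) returns for a 1-D input: the k largest values, largest first, together with their 0-based positions. The helper name topk_indices is hypothetical, and heapq.nlargest stands in for PyTorch's own implementation.

```python
import heapq

def topk_indices(values, k):
    """Hypothetical helper mirroring t.topk(k) on a 1-D tensor.

    Returns (top_values, top_indices) for the k largest entries, largest
    first, with 0-based indices as in PyTorch.
    """
    # nlargest keeps a heap of at most k (index, value) pairs instead of
    # sorting the whole input, ordering the pairs by their value.
    pairs = heapq.nlargest(k, enumerate(values), key=lambda p: p[1])
    top_indices = [i for i, _ in pairs]
    top_values = [v for _, v in pairs]
    return top_values, top_indices

values, indices = topk_indices([5.7, 1.4, 9.5, 1.6, 6.1, 4.3], 2)
print(values)   # [9.5, 6.1]
print(indices)  # [2, 4]
```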

Upvotes: 8

deltheil

Reputation: 16121

As of pull request #496, Torch now includes a built-in function named torch.topk. Example:

> t = torch.Tensor{9, 1, 8, 2, 7, 3, 6, 4, 5}

-- obtain the 3 smallest elements
> res = t:topk(3)
> print(res)
 1
 2
 3
[torch.DoubleTensor of size 3]

-- you can also get the indices in addition
> res, ind = t:topk(3)
> print(ind)
 2
 4
 6
[torch.LongTensor of size 3]

-- alternatively, you can obtain the k largest elements as follows
-- (see the API documentation for more details)
> res = t:topk(3, true)
> print(res)
 9
 8
 7
[torch.DoubleTensor of size 3]

At the time of writing, the CPU implementation follows a sort-and-narrow approach (there are plans to improve it in the future). That said, an optimized GPU implementation for cutorch is currently under review.
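The sort-and-narrow idea can be sketched in plain Python: sort every position by its value, then keep ("narrow to") the first k. This is only an illustration on a plain list, not Torch's actual C code; the function name topk_sort_narrow is hypothetical, and it returns 1-based indices and defaults to the k smallest elements to match the Lua session above.

```python
def topk_sort_narrow(values, k, largest=False):
    """Hypothetical sketch of sort-then-narrow top-k.

    Returns (top_values, top_indices) with 1-based indices, matching the
    Lua Torch output above. Defaults to the k smallest, like t:topk(k).
    """
    # Sort every position by its value: O(n log n) no matter how small k is.
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    # Narrow to the first k positions of the sorted order.
    kept = order[:k]
    return [values[i] for i in kept], [i + 1 for i in kept]

t = [9, 1, 8, 2, 7, 3, 6, 4, 5]
print(topk_sort_narrow(t, 3))                # ([1, 2, 3], [2, 4, 6])
print(topk_sort_narrow(t, 3, largest=True))  # ([9, 8, 7], [1, 3, 5])
```

Because the full sort costs O(n log n) even when k is tiny, this approach leaves room for the improvements the answer mentions.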

Upvotes: 8

binarymax

Reputation: 3405

Just loop through the tensor and compare each element with the largest value seen so far:

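require 'torch'

local data = torch.Tensor({1,2,3,4,505,6,7,8,9,10,11,12})
local idx  = 1        -- index of the largest element seen so far (1-based)
local max  = data[1]  -- largest value seen so far

-- data[1] is already covered, so start comparing from the second element
for i = 2, data:size(1) do
   if data[i] > max then
      max = data[i]
      idx = i
   end
end

print(idx, max)  -- prints the index 5 and the value 505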
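The loop above finds only the single largest element. A minimal sketch of extending the same one-pass idea to k elements keeps a min-heap of the k best (value, index) pairs seen so far and replaces the weakest one whenever a larger value arrives. It is written in Python rather than Lua, the helper name topk_single_pass is hypothetical, and it uses 1-based indices to match the Lua loop.

```python
import heapq

def topk_single_pass(values, k):
    """Hypothetical helper: one pass over values with a size-k min-heap.

    Returns the k largest (value, index) pairs, largest first, with 1-based
    indices to match the Lua loop above.
    """
    if k <= 0:
        return []
    heap = []  # min-heap of (value, index); heap[0] is the weakest kept pair
    for i, v in enumerate(values, start=1):
        if len(heap) < k:
            heapq.heappush(heap, (v, i))
        elif v > heap[0][0]:
            # v beats the smallest of the current top k, so swap it in.
            heapq.heapreplace(heap, (v, i))
    return sorted(heap, reverse=True)

data = [1, 2, 3, 4, 505, 6, 7, 8, 9, 10, 11, 12]
print(topk_single_pass(data, 3))  # [(505, 5), (12, 12), (11, 11)]
```

With k = 1 this reduces to the original loop, returning [(505, 5)]; for larger k it costs O(n log k) rather than a full sort.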

--EDIT-- Responding to your edit: Use the torch.max operation documented here: https://github.com/torch/torch7/blob/master/doc/maths.md#torchmaxresval-resind-x-dim ...

y, i = torch.max(x, 1) returns the largest element in each column (across rows) of x, and a tensor i of their corresponding indices in x. For a 1-D tensor this yields only the single largest element and its index, i.e. the k = 1 case.
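The meaning of reducing over dimension 1 can be illustrated with a small pure-Python sketch on a matrix stored as a list of rows: each column collapses to one maximum and the 1-based row index where it occurs. The function name max_over_rows is hypothetical, and unlike Lua Torch (which returns 1×n tensors) it returns flat lists.

```python
def max_over_rows(x):
    """Hypothetical sketch of torch.max(x, 1) on a list of rows.

    For each column, returns the largest value and the 1-based row index
    where it occurs, as flat lists (maxima, indices).
    """
    maxima, indices = [], []
    for col in range(len(x[0])):
        column = [row[col] for row in x]
        # Row position of the largest value in this column (0-based here).
        best = max(range(len(column)), key=lambda r: column[r])
        maxima.append(column[best])
        indices.append(best + 1)
    return maxima, indices

x = [[1, 9, 3],
     [7, 2, 8]]
print(max_over_rows(x))  # ([7, 9, 8], [2, 1, 2])
```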

Upvotes: 0
