klubow

Reputation: 126

Get indices of N highest values in numpy array

My code:

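import numpy as np

N = 2
a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])

# Replace NaN with -1 so it sorts below the (non-negative) data, then keep
# the column indices of the N largest entries in each row (largest last).
ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]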


a
Out[2]: 
array([[ 0.5,  0.3,  0.2],
       [ 0.2,  0.6,  0.2],
       [ 0.3,  0.2,  0.7],
       [ nan,  0.2,  0.8],
       [ nan,  nan,  0.8]])

ind
Out[3]: 
array([[1, 0],
       [2, 1],
       [0, 2],
       [1, 2],
       [1, 2]], dtype=int64)

Here ind[:, 1] holds the column index of the highest value in each row and ind[:, 0] the index of the second highest.
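The np.where replacement in the code above is what keeps NaN out of the top N: np.argsort always places NaN after every real number, so a raw argsort would report a NaN as the "highest" value of its row. The sketch below (not from the original post) shows this, and swaps the -1 sentinel for -np.inf, an assumption that is only safer if the data can contain values below -1. kind="stable" is added so that tied values keep a deterministic order.

```python
import numpy as np

N = 2
a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])

# Without replacement, NaN sorts to the end and is treated as the largest value.
raw = np.argsort(a, axis=1, kind="stable")[:, -N:]
print(raw[3])  # [2 0]: the NaN in column 0 is ranked highest

# With -inf as the sentinel, NaN sorts below any real value.
filled = np.where(np.isnan(a), -np.inf, a)
ind = np.argsort(filled, axis=1, kind="stable")[:, -N:]
print(ind)
```

With the stable sort, ind matches the output shown in the question.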

That works, except for the last row, which contains two NaNs: its second-highest entry is a NaN. How can I ignore the second-highest value when it is NaN? The desired output would be:

array([[1, 0],
       [2, 1],
       [0, 2],
       [1, 2],
       [nan, 2]], dtype=int64)

Bonus question: how can a tie be broken randomly, as in row a[1, :], where 0.2 appears twice?
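One possible approach to the bonus question (a sketch, not taken from the answer): sort with np.lexsort using a random secondary key, so unequal values keep their ranking while equal values are ordered by a random draw. The function name top_n_random_ties and its rng parameter are hypothetical, and NaN is handled with a -np.inf sentinel as an assumption.

```python
import numpy as np

def top_n_random_ties(a, n, rng=None):
    """Hypothetical helper: indices of the n largest values per row
    (largest last), with ties between equal values broken at random."""
    rng = np.random.default_rng() if rng is None else rng
    # Treat NaN as the smallest possible value.
    filled = np.where(np.isnan(a), -np.inf, a)
    # One random number per element, used only to order equal values.
    tie_break = rng.random(a.shape)
    # np.lexsort uses the LAST key as the primary key, so rows are sorted
    # by `filled` first and by `tie_break` only where `filled` is equal.
    order = np.lexsort((tie_break, filled), axis=1)
    return order[:, -n:]

a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])

print(top_n_random_ties(a, 2))  # row 1 gives [0 1] or [2 1] at random
```

The NaN mask from the answer can be applied to this result in the same way, since the last row's second slot still points at a NaN.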

Upvotes: 1

Views: 112

Answers (1)

Divakar

Reputation: 221564

Use advanced indexing to pull out the values of a at the positions in ind, check those values for NaNs to get a mask, and then use the mask with np.where to choose between NaN and the index, like so:

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind]

In [245]: mask = np.isnan(a_ind)

In [246]: np.where(mask, np.nan, ind)
Out[246]: 
array([[  1.,   0.],
       [  2.,   1.],
       [  0.,   2.],
       [  1.,   2.],
       [ nan,   2.]])

Note that an array containing NaN must have a float dtype, so the final output is float rather than int64.
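The same steps can be wrapped into one self-contained function. This is a minimal sketch, not the answerer's code: the name top_n_indices is hypothetical, np.take_along_axis stands in for the a[np.arange(ind.shape[0])[:, None], ind] indexing (both select the same elements), and the -np.inf sentinel and kind="stable" are assumptions added for robustness and deterministic tie order.

```python
import numpy as np

def top_n_indices(a, n):
    """Hypothetical helper: column indices of the n largest values per row
    (largest last). Slots that could only be filled by a NaN are reported
    as NaN, so the result has a float dtype."""
    filled = np.where(np.isnan(a), -np.inf, a)
    ind = np.argsort(filled, axis=1, kind="stable")[:, -n:]
    # Values of `a` at the chosen positions (same as the answer's indexing).
    picked = np.take_along_axis(a, ind, axis=1)
    # Keep the index where the picked value is real, NaN otherwise.
    return np.where(np.isnan(picked), np.nan, ind)

a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])

result = top_n_indices(a, 2)
print(result)
```

If an integer result is preferred, a sentinel such as -1 could replace np.nan in the final np.where; that is a design choice, not part of the answer.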

Upvotes: 1
