My code:
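import numpy as np
N = 2
a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])
# Replace NaN with -1 so NaNs sort below every real value,
# then keep the column indices of the N largest values in each row.
ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]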
a
Out[2]:
array([[ 0.5, 0.3, 0.2],
[ 0.2, 0.6, 0.2],
[ 0.3, 0.2, 0.7],
[ nan, 0.2, 0.8],
[ nan, nan, 0.8]])
ind
Out[3]:
array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[1, 2]], dtype=int64)
Here ind[:,1] is the index of the highest value in each row and ind[:,0] the index of the second highest.
This works, except for the last row, which contains two NaNs. How can I ignore the second-highest value when it is NaN? The desired output would be:
array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[nan, 2]], dtype=int64)
Bonus question: how can I break ties randomly, as in a[1,:], where 0.2 appears twice?
Use advanced indexing to pull out the values of a at the positions in ind, check those values for NaNs to get a mask, and then use the mask with np.where to choose between NaN and the index, like so:
In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind]
In [245]: mask = np.isnan(a_ind)
In [246]: np.where(mask, np.nan, ind)
Out[246]:
array([[ 1., 0.],
[ 2., 1.],
[ 0., 2.],
[ 1., 2.],
[ nan, 2.]])
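The advanced-indexing step can also be written with np.take_along_axis (available since NumPy 1.15), which gathers a[i, ind[i, j]] row by row without building the np.arange(...)[:, None] row index by hand. A minimal self-contained sketch of the same approach, assuming the question's array and N = 2:

```python
import numpy as np

N = 2
a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])

# Column indices of the N largest values per row; NaN -> -1 so NaNs sort first
# (this assumes every real value is greater than -1).
ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]

# Same as a[np.arange(ind.shape[0])[:, None], ind].
a_ind = np.take_along_axis(a, ind, axis=1)

# Wherever the selected value is NaN, report NaN instead of an index.
result = np.where(np.isnan(a_ind), np.nan, ind)
print(result)
```

The last row comes out as [nan, 2.], and the whole result is float for the reason given below.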
Note that an array can only hold NaN if it has a float dtype, so the final output is float rather than int64.
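The bonus question about random tie-breaking is not covered above. One possible approach, sketched here as an assumption rather than part of the original answer, is to sort with np.lexsort using a random array as a secondary key: lexsort treats the last key as the primary one, so the random key only reorders equal values. The function name top_n_indices is hypothetical, NaN is mapped to -inf (instead of -1) so no assumption about the value range is needed, and np.random.default_rng requires NumPy 1.17 or later.

```python
import numpy as np

def top_n_indices(a, n, rng=None):
    """Hypothetical helper: indices of the n largest values per row of `a`
    (last column = largest), ties broken at random, NaN positions -> NaN."""
    rng = np.random.default_rng() if rng is None else rng
    # NaN -> -inf so NaNs always sort below real values.
    vals = np.where(np.isnan(a), -np.inf, a)
    # Primary key: vals (last in the tuple); secondary key: random numbers,
    # which only decide the order among equal values.
    order = np.lexsort((rng.random(a.shape), vals), axis=1)
    ind = order[:, -n:]
    # Mask indices whose selected value is NaN (result becomes float).
    picked = np.take_along_axis(a, ind, axis=1)
    return np.where(np.isnan(picked), np.nan, ind)

a = np.array([[0.5, 0.3, 0.2],
              [0.2, 0.6, 0.2],
              [0.3, 0.2, 0.7],
              [np.nan, 0.2, 0.8],
              [np.nan, np.nan, 0.8]])
print(top_n_indices(a, 2, np.random.default_rng(0)))
```

For a[1,:], the second column of the result is always 1 (the 0.6), while the first column is 0 or 2 depending on the random key; pass a seeded generator to make the choice reproducible.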