I want to extract the indexes of the 3 highest values in each row of a pandas DataFrame. Right now I am using:
top3df = df.apply(lambda x: pd.Series(x.nlargest(3).index), axis=1)
Unfortunately, this is quite costly: on my example dataset of 2,000,000 rows x 80 columns it takes about 30 minutes. Is there a faster way?
You can use np.sort with axis=1 to sort each row, then [:, ::-1] to reverse the order so the largest values come first, and then [:, :3] to select the first 3 columns of the array. Finally, wrap the result in a new DataFrame:
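# input
import numpy as np
import pandas as pd

np.random.seed(3)
df = pd.DataFrame(np.random.randint(0, 100, 100).reshape(10, 10),
                  columns=list('abcdefghij'))
# sort each row ascending, reverse it to descending, keep the first 3 columns
top3 = pd.DataFrame(np.sort(df.to_numpy(), axis=1)[:, ::-1][:, :3])
print(top3)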
    0   1   2
0  74  72  56
1  96  93  81
2  90  90  69
3  97  79  62
4  94  78  64
5  85  71  63
6  99  91  80
7  96  95  61
8  91  90  74
9  88  60  56
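To see what each step of that chain does, here is a small sketch on a hand-made 2 x 5 frame. The values are invented for illustration (distinct within each row) so the intermediate results can be checked by eye. It also compares the result with the question's apply/nlargest approach, adapted to return values instead of labels, to confirm that both agree:

```python
import numpy as np
import pandas as pd

# Hypothetical 2 x 5 frame with distinct values, chosen so each step is easy to check
df = pd.DataFrame([[5, 1, 9, 3, 7],
                   [2, 8, 4, 6, 0]],
                  columns=list('abcde'))

arr = df.to_numpy()
ascending = np.sort(arr, axis=1)   # rows sorted low -> high: [1 3 5 7 9], [0 2 4 6 8]
descending = ascending[:, ::-1]    # column order reversed, high -> low: [9 7 5 3 1], [8 6 4 2 0]
top3 = descending[:, :3]           # first three columns: [9 7 5], [8 6 4]

# Row-wise version from the question, returning the values instead of the labels
slow = df.apply(lambda row: pd.Series(row.nlargest(3).to_numpy()), axis=1)
print((slow.to_numpy() == top3).all())  # True
```

Both slicing steps only create views of the sorted array, so the cost is dominated by the single vectorized np.sort call rather than by a Python-level function call per row.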
EDIT: The OP changed the question to ask for the column names of the top 3 values in each row. That can be done with argsort and by indexing the column names with the result:
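# argsort gives each row's column positions in ascending order of value;
# -1:-4:-1 takes the last three positions, largest value first
names = df.columns.to_numpy()
print(pd.DataFrame(names[np.argsort(df.to_numpy(), axis=1)][:, -1:-4:-1]))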
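argsort fully sorts all 80 columns of every row only to keep three of them. A possible further speed-up, which is not part of the original answer and has not been benchmarked here, is np.argpartition: it only guarantees that the positions of the k largest values end up in the last k slots of each row. The sketch below assumes numeric data and k no larger than the number of columns; top_k_column_names is a hypothetical helper name. The k selected positions are then sorted among themselves so the names come out largest first:

```python
import numpy as np
import pandas as pd


def top_k_column_names(df: pd.DataFrame, k: int = 3) -> pd.DataFrame:
    """Hypothetical helper: names of the k largest columns per row, largest first."""
    arr = df.to_numpy()
    # argpartition moves the positions of each row's k largest values into the
    # last k slots (in no particular order) without fully sorting the row
    part = np.argpartition(arr, -k, axis=1)[:, -k:]
    # Look up the values at those k positions ...
    part_vals = np.take_along_axis(arr, part, axis=1)
    # ... and order only those k positions from largest to smallest value
    order = np.argsort(part_vals, axis=1)[:, ::-1]
    top_idx = np.take_along_axis(part, order, axis=1)
    # Translate column positions into column labels
    return pd.DataFrame(df.columns.to_numpy()[top_idx], index=df.index)


# Usage on a small frame with distinct values per row
df = pd.DataFrame([[5, 1, 9, 3, 7],
                   [2, 8, 4, 6, 0]],
                  columns=list('abcde'))
print(top_k_column_names(df))  # row 0: c, e, a; row 1: b, d, c
```

With tied values, argpartition and argsort may pick or order the tied columns differently from nlargest, which keeps the first occurrence by default, so results can differ on rows with ties.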