In a pandas dataframe, for every row, I want to keep only the top N values and set everything else to 0. I can iterate through the rows and do it, but I am sure python/pandas can do it elegantly in a single line.
For example, for N = 2:
Input:
A B C D
4 10 10 6
5 20 50 90
6 30 6 4
7 40 12 9
Output:
A B C D
0 10 10 0
0 0 50 90
6 30 6 0
0 40 12 0
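For reference, the row-by-row loop mentioned above might look like the sketch below. The helper name keep_top_n_loop is hypothetical, and it assumes that values tied with the N-th largest value are kept, which is what the expected output shows for row 2 (both 6s survive).

```python
import pandas as pd

def keep_top_n_loop(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Hypothetical helper: zero everything below each row's n-th largest value, one row at a time."""
    out = df.copy()
    for idx, row in df.iterrows():
        # n-th largest value of this row; values tied with it are kept,
        # matching the expected output (row 2 keeps both 6s).
        threshold = row.nlargest(n).iloc[-1]
        out.loc[idx, row < threshold] = 0
    return out

df = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                   'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
print(keep_top_n_loop(df, 2))
```

The vectorised answers below avoid the Python-level loop over rows.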
Using rank with the parameters axis=1, method='min' and ascending=False:
import numpy as np
import pandas as pd
N = 2
df = df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0)
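Or use np.where and rebuild a pd.DataFrame, which is faster than the mask method. np.where returns a bare array, so pass both the index and the columns back in:
df = pd.DataFrame(np.where(df.rank(axis=1, method='min', ascending=False) > N, 0, df),
                  index=df.index, columns=df.columns)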
print(df)
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
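A self-contained version of both one-liners, assuming the example frame from the question. It computes the ranks once and reuses them for the mask and np.where variants.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                   'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
N = 2

# Rank each row from largest (rank 1) downwards; tied values share the lowest rank.
ranks = df.rank(axis=1, method='min', ascending=False)

# Variant 1: mask replaces the cells where the condition is True.
masked = df.mask(ranks > N, 0)

# Variant 2: np.where returns an ndarray, so rebuild the frame with its labels.
via_where = pd.DataFrame(np.where(ranks > N, 0, df), index=df.index, columns=df.columns)

print(masked)
```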
Step 1: First we need to find the N largest values in each row, taking duplicates into account. axis=1 ranks the values within each row (across the columns), ascending=False gives the largest value rank 1, and method='min' gives tied values the same, lowest rank (the two 10s in row 0 both get rank 1):
print(df.rank(axis=1, method='min', ascending=False))
A B C D
0 4.0 1.0 1.0 3.0
1 4.0 3.0 2.0 1.0
2 2.0 1.0 2.0 4.0
3 4.0 1.0 2.0 3.0
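To see what method='min' does with ties, the sketch below compares it with method='first', an alternative this answer does not use. With 'min' the two 6s in row 2 share rank 2, so that row has three cells with rank <= N; with 'first' ties are broken by position and every row keeps exactly N cells.

```python
import pandas as pd

df = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                   'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
N = 2

# Ties share the lowest rank: both 6s in row 2 get rank 2.
min_ranks = df.rank(axis=1, method='min', ascending=False)
# Ties are ranked in order of appearance, so no rank repeats within a row.
first_ranks = df.rank(axis=1, method='first', ascending=False)

# Number of cells per row that pass the "rank <= N" test.
print((min_ranks <= N).sum(axis=1).tolist())
print((first_ranks <= N).sum(axis=1).tolist())
```

Which one to use depends on whether tied values at the cut-off should all be kept (as in the question's expected output) or whether exactly N values per row are wanted.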
Step 2: Next, flag the cells whose rank is greater than N, then replace those values with 0 using mask. Because of ties, a row can keep more than N values (row 2 keeps both 6s):
print(df.rank(axis=1, method='min', ascending=False) > N)
A B C D
0 True False False True
1 True True False False
2 False False False True
3 True False False True
print(df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0))
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
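The other answers switch between mask and where. The two are complements: mask replaces cells where the condition is True, where keeps them. The sketch below, assuming the example frame, shows that inverting the condition gives the same result.

```python
import pandas as pd

df = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                   'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
N = 2
too_low = df.rank(axis=1, method='min', ascending=False) > N

# mask replaces cells where the condition is True ...
via_mask = df.mask(too_low, 0)
# ... while where keeps cells where the condition is True, so the condition is inverted.
via_where = df.where(~too_low, 0)

print(via_mask.equals(via_where))
```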
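You can use scipy.stats.rankdata via np.apply_along_axis and feed the result to pd.DataFrame.where. rankdata ranks in ascending order, so with method='max' a value is among the N largest of its row (ties included) when its rank is greater than the number of columns minus N. With 4 columns and N = 2 that threshold is 2:
import numpy as np
from scipy.stats import rankdata
N = 2
df[:] = df.where(np.apply_along_axis(rankdata, 1, df, method='max') > df.shape[1] - N, 0)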
print(df)
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
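If SciPy is not available, a NumPy-only variant (not part of this answer) reaches the same result: sort each row, take the N-th largest value as a per-row threshold, and keep every value at least that large. This keeps ties in the same way as rank(method='min') <= N. The helper name keep_top_n is hypothetical.

```python
import numpy as np
import pandas as pd

def keep_top_n(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Hypothetical helper: zero everything below each row's n-th largest value."""
    values = df.to_numpy()
    # Ascending sort per row; column -n holds the n-th largest value.
    threshold = np.sort(values, axis=1)[:, -n]
    # Broadcast the per-row threshold across the columns; values tied with it are kept.
    kept = np.where(values >= threshold[:, None], values, 0)
    return pd.DataFrame(kept, index=df.index, columns=df.columns)

example = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                        'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
print(keep_top_n(example, 2))
```

No timings for this variant are reported in the thread.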
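pd.DataFrame.rank is the most efficient of the solutions below; apply + lambda performs worst (timings on the example frame repeated to 400 rows):
import numpy as np
import pandas as pd
from scipy.stats import rankdata
from heapq import nlargest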
df = pd.concat([df]*100, ignore_index=True)
%timeit df.mask(df.rank(axis=1, method='min', ascending=False) > 2, 0) # 2.23 ms per loop
%timeit df.where(np.apply_along_axis(rankdata, 1, df, method='max') > 2, 0) # 45 ms per loop
%timeit df.where(df.apply(lambda x: x.isin(nlargest(2, x)), axis=1), 0) # 92.4 ms per loop
%timeit df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0) # 274 ms per loop
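%timeit is IPython-only. In a plain script, a similar comparison can be run with the standard-library timeit module, as sketched below. It leaves out the SciPy variant to avoid the extra dependency, and absolute numbers will differ from the ones above depending on the machine and pandas version.

```python
import timeit
from heapq import nlargest

import pandas as pd

base = pd.DataFrame({'A': [4, 5, 6, 7], 'B': [10, 20, 30, 40],
                     'C': [10, 50, 6, 12], 'D': [6, 90, 4, 9]})
# Same enlargement as above: 400 rows.
df = pd.concat([base] * 100, ignore_index=True)

candidates = {
    'rank + mask': lambda: df.mask(df.rank(axis=1, method='min', ascending=False) > 2, 0),
    'apply + heapq.nlargest': lambda: df.where(df.apply(lambda x: x.isin(nlargest(2, x)), axis=1), 0),
    'apply + Series.nlargest': lambda: df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0),
}

timings = {}
for name, fn in candidates.items():
    # Best of 3 repeats of 2 calls each, converted to milliseconds per call.
    best = min(timeit.repeat(fn, number=2, repeat=3)) / 2
    timings[name] = best * 1000
    print(f'{name}: {timings[name]:.2f} ms per call')
```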
Use:
N = 2
df = df.where(df.apply(lambda x: x.isin(x.nlargest(N)), axis=1), 0)
print(df)
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
Or:
import heapq
N = 2
df = df.where(df.apply(lambda x: x.isin(heapq.nlargest(N, x)), axis=1), 0)
print(df)
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0
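Row 2 keeps three values here too, even though nlargest returns exactly N entries, because isin matches by value: every cell equal to a selected value is kept. A small check on row 2 of the example frame:

```python
import pandas as pd

row = pd.Series([6, 30, 6, 4], index=list('ABCD'))  # row 2 of the example frame
top = row.nlargest(2)   # exactly two entries: B=30 and the first 6 (A)
kept = row.isin(top)    # isin compares by value, so C=6 matches as well
print(top.index.tolist(), kept.tolist())
```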
Use nlargest to get the N largest numbers in each row:
df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0)
Output:
A B C D
0 0 10 10 0
1 0 0 50 90
2 6 30 6 0
3 0 40 12 0