Reputation: 1516
I need to find all column labels where the maximum value (per row) is attained in a pandas DataFrame. For instance, if I have a DataFrame like this:
   cat1  cat2  cat3
0     0     2     2
1     3     0     1
2     1     1     0
then the method I am looking for would yield a result like:
[['cat2', 'cat3'],
['cat1'],
['cat1', 'cat2']]
This is a list of lists, but some other data structure is also okay.
I cannot use df.idxmax(axis=1), because it only yields the first maximum.
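To make the limitation concrete, here is a minimal sketch on the frame above. The way df is constructed here is an assumption about how the data was built; only the values come from the table.

```python
import pandas as pd

# Assumed construction of the example frame shown in the question.
df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})

# idxmax reports only the first column that reaches the row maximum,
# so the ties in rows 0 and 2 are lost.
first_max = df.idxmax(axis=1).tolist()
print(first_max)  # ['cat2', 'cat1', 'cat1']
```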
Upvotes: 13
Views: 6320
Reputation: 77027
You could do
In [2560]: cols = df.columns.values
In [2561]: vals = df.values
In [2562]: [cols[v].tolist() for v in vals == vals.max(1)[:, None]]
Out[2562]: [['cat2', 'cat3'],
['cat1'],
['cat1', 'cat2']]
Update
Here is a full example:
import pandas as pd
import numpy as np
np.random.seed(400)
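df = pd.DataFrame({
    'a': np.random.randint(0, 3, size=10),
    'b': np.random.randint(0, 3, size=10),
    'c': np.random.randint(0, 5, size=10),
})
print(df)
# Compare against the row maxima as a NumPy column vector; indexing a
# Series with [:, None] is not supported by recent pandas versions.
out = [df.columns[i].tolist() for i in df.values == df.max(axis=1).values[:, None]]
for i in out:
    print(i)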
Output of print(df):
   a  b  c
0  0  1  4
1  2  2  4
2  1  1  1
3  0  1  3
4  2  2  1
5  1  1  1
6  0  2  4
7  2  0  2
8  2  1  3
9  2  2  4
And the output of the loop over out:
['c']
['c']
['a', 'b', 'c']
['c']
['a', 'b']
['a', 'b', 'c']
['c']
['a', 'c']
['c']
['c']
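A pandas-flavoured variant of the same idea, as a sketch rather than a replacement for the code above: DataFrame.eq with axis=0 builds the boolean mask without manual broadcasting. The names mask and max_cols are illustrative, and the small frame from the question is used so the result is easy to check by hand.

```python
import pandas as pd

df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})

# True wherever a cell equals its row's maximum; axis=0 aligns the
# Series of row maxima with the row index.
mask = df.eq(df.max(axis=1), axis=0)

# For each row of the mask, keep the column labels marked True.
max_cols = [df.columns[row].tolist() for row in mask.to_numpy()]
print(max_cols)  # [['cat2', 'cat3'], ['cat1'], ['cat1', 'cat2']]
```

If the row labels matter, dict(zip(df.index, max_cols)) pairs each list with its index label.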
Upvotes: 1
Reputation: 880927
Here is the information, in a different data structure:
In [8]: df = pd.DataFrame({'cat1':[0,3,1], 'cat2':[2,0,1], 'cat3':[2,1,0]})
In [9]: df
Out[9]:
   cat1  cat2  cat3
0     0     2     2
1     3     0     1
2     1     1     0
[3 rows x 3 columns]
In [10]: rowmax = df.max(axis=1).values  # NumPy array, so rowmax[:,None] works below
The max values are indicated by True values:
In [82]: df.values == rowmax[:,None]
Out[82]:
array([[False, True, True],
[ True, False, False],
[ True, True, False]], dtype=bool)
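The comparison above relies on NumPy broadcasting. The following sketch spells out the shapes involved, using a plain array with the same values as the example frame; the names vals, col and mask are illustrative only.

```python
import numpy as np

vals = np.array([[0, 2, 2],
                 [3, 0, 1],
                 [1, 1, 0]])

rowmax = vals.max(axis=1)   # shape (3,): one maximum per row
col = rowmax[:, None]       # shape (3, 1): a column vector
mask = vals == col          # broadcast to shape (3, 3)

print(rowmax.shape, col.shape, mask.shape)
print(mask)
```

Without the [:, None], the 1-D array of maxima would be lined up against the columns of each row instead of against the rows, which is not the intended comparison.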
np.where returns the indices where the boolean array above is True.
In [84]: np.where(df.values == rowmax[:,None])
Out[84]: (array([0, 0, 1, 2, 2]), array([1, 2, 0, 0, 1]))
The first array indicates index values for axis=0, the second array for axis=1. There are 5 values in each array since there are five locations that are True.
You could use itertools.groupby to build the list of lists you posted, though perhaps you don't need this given the data structures above:
In [46]: import itertools as IT
In [47]: import operator
In [48]: idx = np.where(df.values == rowmax[:,None])
In [49]: groups = IT.groupby(zip(*idx), key=operator.itemgetter(0))
In [50]: [[df.columns[j] for i, j in grp] for k, grp in groups]
Out[50]: [['cat2', 'cat3'], ['cat1'], ['cat1', 'cat2']]
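As a self-contained sketch of the same np.where idea without itertools.groupby, the coordinates can be collected into a dict keyed by row label. The function name columns_at_row_max is hypothetical and not part of pandas or NumPy.

```python
from collections import defaultdict

import numpy as np
import pandas as pd


def columns_at_row_max(df):
    """Map each row label to the columns holding that row's maximum.

    Hypothetical helper for illustration; assumes numeric data.
    """
    vals = df.to_numpy()
    # Row and column positions of every cell equal to its row maximum.
    rows, cols = np.where(vals == vals.max(axis=1)[:, None])
    result = defaultdict(list)
    for i, j in zip(rows, cols):
        result[df.index[i]].append(df.columns[j])
    return dict(result)


df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})
result = columns_at_row_max(df)
print(result)
```

One caveat with this sketch: a row containing NaN has a NaN maximum, which compares unequal to everything, so that row would be missing from the result.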
Upvotes: 4