Reputation: 43
How can I compare all columns of a DataFrame with each other and remove every column that is 'less than' some other column, according to an arbitrary comparison function (where the comparison function is transitive)?
For example, if I have the DataFrame
0 1 2 3 4 5
0 4 1 1 3 6 2
1 4 2 2 7 2 6
2 4 3 3 3 6 2
3 4 8 3 7 2 6
and my function is "column A is less than column B if A[i] < B[i] for every row i", the result would be
0 1 3 4
0 4 1 3 6
1 4 2 7 2
2 4 3 3 6
3 4 8 7 2
dropping column 2 because column 0 is larger in every row (4>1, 4>2, 4>3, 4>3), and column 5 because column 3 is larger in every row (3>2, 7>6, 3>2, 7>6).
My initial, obvious-but-slow approach takes O(n^2) comparisons (pseudocode; I haven't done much pandas programming before, so answers with real code would be appreciated):
for i in range(0, n):
    for j in range(0, n):
        if my_less_than_function(col(i), col(j)):
            # i < j
            drop col(i)
If the less-than function is transitive, I could also remember which columns I've already dropped and skip them when iterating over i and j. I could also iterate j in range(i + 1, n) if my comparison function returned (-1, 0, 1) for (less, equal, more) instead of (True, False) for (less, equal or more).
Note that the comparison function may not be row-wise; for example, it could be sum(col A) < sum(col B) or number_of_primes_in(col A) < number_of_primes_in(col B).
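A minimal runnable version of the nested loop above, as a sketch only. It assumes the comparison is a strict order (irreflexive and transitive); `drop_dominated` and `all_rows_less` are hypothetical names, with `all_rows_less` standing in for the example "A[i] < B[i] for every row i" rule. Under that assumption, a column that has already been dropped never needs to be a comparison target: anything below it is also below some column that is never dropped.

```python
import pandas as pd


def all_rows_less(a: pd.Series, b: pd.Series) -> bool:
    """Hypothetical example comparator: a < b in every row."""
    return bool((a < b).all())


def drop_dominated(df: pd.DataFrame, less_than) -> pd.DataFrame:
    """Drop every column that is less_than some other column.

    Assumes less_than is a strict order (irreflexive and transitive), which is
    what makes skipping already-dropped columns as comparison targets safe.
    """
    dropped = set()
    for i in df.columns:
        for j in df.columns:
            if j == i or j in dropped:
                continue  # never compare with itself or with a dropped column
            if less_than(df[i], df[j]):
                dropped.add(i)
                break  # one larger column is enough to drop i
    return df.drop(columns=[c for c in df.columns if c in dropped])


df = pd.DataFrame([[4, 1, 1, 3, 6, 2],
                   [4, 2, 2, 7, 2, 6],
                   [4, 3, 3, 3, 6, 2],
                   [4, 8, 3, 7, 2, 6]])
print(drop_dominated(df, all_rows_less).columns.tolist())  # [0, 1, 3, 4]
```

The comparator is just a function of two Series, so a non-row-wise rule such as `lambda a, b: a.sum() < b.sum()` can be passed in the same way. This is still O(n^2) comparisons in the worst case; the skipping only prunes work.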
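When the comparison is induced by a single scalar per column, as in the sum(col A) < sum(col B) example, the pairwise loop is not needed at all: a column is dropped exactly when some other column has a strictly larger key, so only the columns whose key equals the maximum survive. A minimal sketch, assuming the key returns comparable scalars; `keep_max_key_columns` is a hypothetical helper name.

```python
import pandas as pd


def keep_max_key_columns(df: pd.DataFrame, key) -> pd.DataFrame:
    """Keep only the columns that no other column beats under key(A) < key(B).

    key maps one column (a Series) to a comparable scalar, e.g. pd.Series.sum.
    Each key is computed once, so this needs n key evaluations, not n^2 comparisons.
    """
    keys = df.apply(key)  # one scalar per column, indexed by column label
    return df.loc[:, keys == keys.max()]  # ties at the maximum are all kept


df = pd.DataFrame([[4, 1, 1, 3, 6, 2],
                   [4, 2, 2, 7, 2, 6],
                   [4, 3, 3, 3, 6, 2],
                   [4, 8, 3, 7, 2, 6]])
print(keep_max_key_columns(df, pd.Series.sum).columns.tolist())  # [3]
```

A prime-counting rule would work the same way by passing a function that counts primes in a Series as `key`.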
Thank you
Upvotes: 2
Views: 6628
Reputation: 294488
Setup
import pandas as pd
df = pd.DataFrame([
[4, 1, 1, 3, 6, 2],
[4, 2, 2, 7, 2, 6],
[4, 3, 3, 3, 6, 2],
[4, 8, 3, 7, 2, 6]
], columns=list('abcdef'))
print(df)
a b c d e f
0 4 1 1 3 6 2
1 4 2 2 7 2 6
2 4 3 3 3 6 2
3 4 8 3 7 2 6
numpy broadcasting
For your definition of less_than, we can use numpy:
v = df.values
lt = pd.DataFrame((v.T[:, None] < v.T).all(-1), df.columns, df.columns)
print(lt)
a b c d e f
a False False False False False False
b False False False False False False
c True False False False False False
d False False False False False False
e False False False False False False
f False False False True False False
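To see what the broadcast actually compares, it can help to look at the intermediate shapes for this 4-row, 6-column frame. This is an illustrative sketch; `left`, `right` and `cube` are names introduced here, not part of the answer. `v.T` has one row per column of `df`; adding an axis gives shape (6, 1, 4), which broadcasts against (6, 4) to (6, 6, 4), and `.all(-1)` then reduces over the original rows.

```python
import pandas as pd

df = pd.DataFrame([[4, 1, 1, 3, 6, 2],
                   [4, 2, 2, 7, 2, 6],
                   [4, 3, 3, 3, 6, 2],
                   [4, 8, 3, 7, 2, 6]], columns=list('abcdef'))
v = df.values

left = v.T[:, None]   # shape (6, 1, 4): column i, broadcast over j, row k
right = v.T           # shape (6, 4), treated as (1, 6, 4): column j, row k
cube = left < right   # shape (6, 6, 4): cube[i, j, k] is v[k, i] < v[k, j]
print(left.shape, right.shape, cube.shape)  # (6, 1, 4) (6, 4) (6, 6, 4)

# Reducing over the last axis (the rows) gives the 6x6 matrix
# lt[i, j] == "column i < column j in every row"
lt = cube.all(-1)
print(lt[2, 0], lt[5, 3])  # True True  ('c' < 'a' and 'f' < 'd')
```

So in the `lt` DataFrame, `lt.loc[x, y]` is True when column x is less than column y in every row.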
You can pull out specific columns in the following way:
all columns that are greater than 'f' (i.e. 'f' is less than them in every row)
df.loc[:, lt.loc['f']]
d
0 3
1 7
2 3
3 7
all columns that are not greater than 'f' (this includes 'f' itself)
df.loc[:, ~lt.loc['f']]
a b c e f
0 4 1 1 6 2
1 4 2 2 2 6
2 4 3 3 6 2
3 4 8 3 2 6
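Neither selection is quite the result the question asks for. One way to get it from the same `lt` matrix (a sketch, not part of the original answer): a column should be dropped when its row in `lt` contains any True, i.e. it is less than at least one other column.

```python
import pandas as pd

df = pd.DataFrame([[4, 1, 1, 3, 6, 2],
                   [4, 2, 2, 7, 2, 6],
                   [4, 3, 3, 3, 6, 2],
                   [4, 8, 3, 7, 2, 6]], columns=list('abcdef'))
v = df.values
lt = pd.DataFrame((v.T[:, None] < v.T).all(-1), df.columns, df.columns)

# Row x of lt is True wherever column x < some column y, so a column is
# dominated if its row contains any True value
dominated = lt.any(axis=1)
result = df.loc[:, ~dominated]
print(result.columns.tolist())  # ['a', 'b', 'd', 'e']
```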
Upvotes: 4
Reputation: 210882
Try this:
In [278]: df
Out[278]:
0 1 2 3 4 5
0 4 1 1 3 6 2
1 4 2 2 7 2 6
2 4 3 3 3 6 2
3 4 8 3 7 2 6
In [279]: cols2drop = [col for col in df.columns if df.T.gt(df[col]).all(axis=1).any()]
In [280]: cols2drop
Out[280]: [2, 5]
In [282]: df = df.drop(columns=cols2drop)
In [283]: df
Out[283]:
0 1 3 4
0 4 1 3 6
1 4 2 7 2
2 4 3 3 6
3 4 8 7 2
Explanation:
In [286]: df.T.gt(df[2])
Out[286]:
0 1 2 3
0 True True True True
1 False False False True
2 False False False False
3 True True False True
4 True False True False
5 True True False True
In [287]: df.T.gt(df[2]).all(axis=1)
Out[287]:
0 True
1 False
2 False
3 False
4 False
5 False
dtype: bool
In [288]: df.T.gt(df[2]).all(axis=1).any()
Out[288]: True
Upvotes: 2