I have the following pd.DataFrame and list of columns:
col_list = ["med_a", "med_c"]
df = pd.DataFrame.from_dict({'med_a': [0, 0, 1, 0], 'med_b': [0, 0, 1, 1], 'med_c': [0, 1, 1, 0]})
print(df)
>>>
   med_a  med_b  med_c
0      0      0      0
1      0      0      1
2      1      1      1
3      0      1      0
I want to make a new column (new_col) that holds True/False (or 0/1) indicating, for each row, whether any of the columns in col_list is equal to 1. So the result should become:
   med_a  med_b  med_c  new_col
0      0      0      0        0
1      0      0      1        1
2      1      1      1        1
3      0      1      0        0
I know how to select only those rows where at least one of the columns is equal to 1, but that checks every column rather than only those in col_list, and it doesn't create a new column:
print(df[(df == 1).any(axis=1)])
>>>
   med_a  med_b  med_c
1      0      0      1
2      1      1      1
3      0      1      0
How would I achieve the desired result? Any help is appreciated.
You're so close! Just select the col_list columns from df before calling any(axis=1), then convert the boolean result with astype(int).
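Broken into steps on the question's data, the chain looks like the sketch below. The intermediate names subset and mask are only for illustration; the one-liner in the answer does the same thing without them.

```python
import pandas as pd

col_list = ["med_a", "med_c"]
df = pd.DataFrame.from_dict({'med_a': [0, 0, 1, 0],
                             'med_b': [0, 0, 1, 1],
                             'med_c': [0, 1, 1, 0]})

# Step 1: keep only the columns named in col_list (med_b is dropped).
subset = df[col_list]

# Step 2: reduce across columns (axis=1); True if any value in the row is non-zero.
mask = subset.any(axis=1)
print(mask.tolist())  # [False, True, True, False]

# Step 3: cast the boolean Series to 0/1 integers.
print(mask.astype(int).tolist())  # [0, 1, 1, 0]
```

Row 3 comes out False because its only 1 is in med_b, which step 1 removed.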
import numpy as np
import pandas as pd
col_list = ["med_a", "med_c"]
df = pd.DataFrame.from_dict({'med_a': [0, 0, 1, 0],
                             'med_b': [0, 0, 1, 1],
                             'med_c': [0, 1, 1, 0]})
df['new_col'] = df[col_list].any(axis=1).astype(int)
print(df)
Or via np.where:
df['new_col'] = np.where(df[col_list].any(axis=1), 1, 0)
Either way, print(df) gives:
   med_a  med_b  med_c  new_col
0      0      0      0        0
1      0      0      1        1
2      1      1      1        1
3      0      1      0        0
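Note that any tests for non-zero (truthy) values, not specifically for 1. With strictly 0/1 columns the two are the same, but if the columns might hold other values, a sketch that compares with eq(1) first (mirroring the question's (df == 1) mask) is below. The value 2 in med_a is made-up data chosen only to show the difference.

```python
import pandas as pd

col_list = ["med_a", "med_c"]
# Hypothetical data: row 3 holds a 2 in med_a, which is non-zero but not equal to 1.
df = pd.DataFrame.from_dict({'med_a': [0, 0, 1, 2],
                             'med_b': [0, 0, 1, 1],
                             'med_c': [0, 1, 1, 0]})

# any() alone treats every non-zero value as True.
df['any_nonzero'] = df[col_list].any(axis=1).astype(int)

# eq(1) first compares each value to 1, so only exact 1s count.
df['any_equal_1'] = df[col_list].eq(1).any(axis=1).astype(int)

print(df)
```

Here any_nonzero flags row 3 while any_equal_1 does not.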
Timing information via perfplot: np.where is faster than astype(int) up to about 100,000 rows, at which point the two are about the same.
import numpy as np
import pandas as pd
import perfplot
np.random.seed(5)
col_list = ["med_a", "med_c"]
def gen_data(n):
    return pd.DataFrame.from_dict({'med_a': np.random.choice([0, 1], size=n),
                                   'med_b': np.random.choice([0, 1], size=n),
                                   'med_c': np.random.choice([0, 1], size=n)})
def np_where(df):
    df['new_col'] = np.where(df[col_list].any(axis=1), 1, 0)
    return df
def astype_int(df):
    df['new_col'] = df[col_list].any(axis=1).astype(int)
    return df
if __name__ == '__main__':
    out = perfplot.bench(
        setup=gen_data,
        kernels=[
            np_where,
            astype_int
        ],
        labels=[
            'np_where',
            'astype_int'
        ],
        n_range=[2 ** k for k in range(25)],
        equality_check=None
    )
    out.save('perfplot_results.png', transparent=False)
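The benchmark passes equality_check=None, so perfplot does not verify that the two kernels return the same result. A quick sanity check, sketched with the same kernel bodies (the sizes tested and the default_rng seed are arbitrary choices, not part of the original benchmark), could be:

```python
import numpy as np
import pandas as pd

col_list = ["med_a", "med_c"]
rng = np.random.default_rng(5)

def gen_data(n):
    # Random 0/1 columns, same shape as the benchmark's setup data.
    return pd.DataFrame({c: rng.choice([0, 1], size=n)
                         for c in ["med_a", "med_b", "med_c"]})

def np_where(df):
    df['new_col'] = np.where(df[col_list].any(axis=1), 1, 0)
    return df

def astype_int(df):
    df['new_col'] = df[col_list].any(axis=1).astype(int)
    return df

for n in [1, 10, 1_000, 100_000]:
    data = gen_data(n)
    # Pass copies so neither kernel sees the column the other one added.
    a = np_where(data.copy())['new_col'].to_numpy()
    b = astype_int(data.copy())['new_col'].to_numpy()
    assert np.array_equal(a, b), f"kernels disagree at n={n}"
print("np_where and astype_int agree on all sizes tested")
```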