Richard Gray

Reputation: 127

How to use groupby and apply on a DataFrame to set all of a group's values in a column to 1 if any of them is 1?

I have a DataFrame with the following structure:

[Image: input DataFrame]

I want to transform the DataFrame so that for every unique user_id, if a column contains a 1, then the whole column should contain 1s for that user_id. Assume that I don't know all of the column names in advance. Based on the above input, the output would be:

[Image: expected output DataFrame]

I have the following code (please excuse how verbose it is):

df = df.groupby('user_id').apply(self.transform_columns)

def transform_columns(self, x):
    x.apply(self.transform)

def transform(self, x):
    if 1 in x:
        for element in x:
            element = 1
        var = x

At the point of the transform function, x is definitely a series. For some reason this code is returning an empty DataFrame. Btw if you also know a way of excluding certain columns from the transformation (e.g. user_id) that would be great. Please help.
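A possible explanation for the empty result, with a repaired sketch of the same approach: transform_columns has no return statement, so groupby.apply receives None for every group, and pandas returns an empty frame when every group result is None. Separately, 1 in x tests the Series index, not its values, and assigning to the loop variable element never writes back into the Series. The sketch below uses plain functions instead of the original self. methods so it runs standalone, and excludes user_id by selecting only the other columns before applying.

```python
import pandas as pd

def transform(col):
    # Compare values explicitly: `1 in col` would test the index labels.
    if (col == 1).any():
        # Build a new Series of 1s instead of rebinding a loop variable.
        return pd.Series(1, index=col.index)
    return col

def transform_columns(group):
    # Return the result; without a return, apply gets None for every group.
    return group.apply(transform)

df = pd.DataFrame({
    'user_id': [33, 33, 33, 33, 22, 22],
    'q1': [1, 0, 0, 0, 0, 0],
    'q2': [0, 0, 0, 0, 1, 0],
    'q3': [0, 1, 0, 0, 0, 1],
})

# Every column except user_id is transformed; user_id is re-attached via join.
value_cols = [c for c in df.columns if c != 'user_id']
transformed = (df.groupby('user_id', group_keys=False)[value_cols]
                 .apply(transform_columns))
out = df[['user_id']].join(transformed)
print(out)
```

group_keys=False keeps the original row index on the result, so join can line each transformed row up with its user_id.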

EDIT: After trying jezrael's answer I get a KeyError on the 'user_id' column (which definitely exists in the df), so here is how I transformed the data into the input state shown above. The initial state of the data was as below:

[Image: initial state of the data]

I transformed it to the state shown in the first image in the question with the following code:

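df2 = self.add_support_columns(df)
df = df.join(df2)  # join aligns rows on the index labels, not on user_id

def add_support_columns(self, df):
    # get_col_name is a helper not shown here; note that these two
    # assignments also modify the caller's df in place
    df['pivot_column'] = df.apply(self.get_col_name, axis=1)
    df['flag'] = 1
    # one row per user_id, one column per pivot_column value
    # (pivot raises ValueError if a user_id/pivot_column pair repeats)
    df = df.pivot(index='user_id', columns='pivot_column')['flag']
    df.reset_index(inplace=True)
    df = df.fillna(0)

    return df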
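A possible issue with this step: the pivot produces one row per user, and after reset_index its index labels are just 0..k-1, so DataFrame.join pairs those rows with the original rows by label rather than by user_id. A minimal sketch of one alternative, assuming a hypothetical long-format input with one row per (user_id, question) pair; the question column is a stand-in for whatever get_col_name computes, since the screenshot of the real data is missing. It builds the indicator table with pivot_table and merges it back on user_id, which copies each user's flags onto every one of that user's rows.

```python
import pandas as pd

# Hypothetical long-format input; the real columns are unknown.
raw = pd.DataFrame({
    'user_id':  [33, 33, 22, 22],
    'question': ['q1', 'q3', 'q2', 'q3'],
})

# pivot_table tolerates duplicate (user_id, question) pairs, unlike pivot,
# and fill_value=0 fills the combinations that never occur.
wide = (raw.assign(flag=1)
           .pivot_table(index='user_id', columns='question',
                        values='flag', aggfunc='max', fill_value=0)
           .reset_index())
wide.columns.name = None

# Merge on user_id (not a join on the row index): every row of a user
# receives that user's flags.
out = raw[['user_id']].merge(wide, on='user_id', how='left')
print(out)
```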

Upvotes: 2

Views: 159

Answers (1)

jezrael

Reputation: 863541

You can use set_index + groupby + transform with any + reset_index:

It works because any treats 1 as True (and 0 as False), so if a group has at least one 1 in a column the lambda returns 1, otherwise 0.

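import pandas as pd

df = pd.DataFrame({
    'user_id' : [33,33,33,33,22,22],
    'q1' : [1,0,0,0,0,0],
    'q2' : [0,0,0,0,1,0],
    'q3' : [0,1,0,0,0,1],
    })
# fix the column order (reindex_axis is deprecated and removed in newer pandas)
df = df.reindex(columns=['user_id','q1','q2','q3'])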
print (df)
   user_id  q1  q2  q3
0       33   1   0   0
1       33   0   0   1
2       33   0   0   0
3       33   0   0   0
4       22   0   1   0
5       22   0   0   1

df = (df.set_index('user_id')
        .groupby('user_id') # or groupby(level=0)
        .transform(lambda x: 1 if x.any() else 0)
        .reset_index())
print (df)
   user_id  q1  q2  q3
0       33   1   0   1
1       33   1   0   1
2       33   1   0   1
3       33   1   0   1
4       22   0   1   1
5       22   0   1   1
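The any behaviour the lambda relies on can be checked on single Series. One caveat worth knowing: any is True for every non-zero value, not only 1, so with data that can contain other numbers the lambda means "1 if the group has any non-zero value". The explicit comparison at the end is one way to test strictly for 1.

```python
import pandas as pd

has_one = pd.Series([0, 1, 0, 0]).any()    # True: the 1 counts as True
all_zero = pd.Series([0, 0]).any()         # False: no non-zero value
# any() is also True for other non-zero values such as 2 or -1
has_two = pd.Series([0, 2]).any()          # True
# If values other than 0/1 can occur, compare explicitly instead
strict = (pd.Series([0, 2]) == 1).any()    # False
print(has_one, all_zero, has_two, strict)
```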

Solution with join:

df = df[['user_id']].join(df.groupby('user_id').transform(lambda x: 1 if x.any() else 0))
print (df)
   user_id  q1  q2  q3
0       33   1   0   1
1       33   1   0   1
2       33   1   0   1
3       33   1   0   1
4       22   0   1   1
5       22   0   1   1

EDIT:

A more dynamic solution with difference + reindex:

# select only some columns to transform
cols = ['q1','q2']
# all other columns are left unchanged
cols2 = df.columns.difference(cols)

df1 = df[cols2].join(df.groupby('user_id')[cols].transform(lambda x: 1 if x.any() else 0))
# restore the original column order
df1 = df1.reindex(columns=df.columns)
print (df1)
   user_id  q1  q2  q3
0       33   1   0   0
1       33   1   0   1
2       33   1   0   0
3       33   1   0   0
4       22   0   1   0
5       22   0   1   1
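The final column-reordering step is needed for two reasons: the join places the untransformed columns first, and Index.difference sorts its result by default. A small illustration of the sorting on the column names used here; sort=False keeps the original relative order instead, though the join order still makes the reindex necessary.

```python
import pandas as pd

cols = pd.Index(['user_id', 'q1', 'q2', 'q3'])
# Default sort=None: the difference comes back sorted
kept = cols.difference(['q1', 'q2'])                     # ['q3', 'user_id']
# sort=False keeps the order in which the labels appear in cols
kept_unsorted = cols.difference(['q1', 'q2'], sort=False)  # ['user_id', 'q3']
print(list(kept), list(kept_unsorted))
```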

The logic can also be inverted:

# select only the columns that are not transformed
cols = ['user_id']
# every other column is transformed
cols2 = df.columns.difference(cols)

df1 = df[cols].join(df.groupby('user_id')[cols2].transform(lambda x: 1 if x.any() else 0))
df1 = df1.reindex(columns=df.columns)
print (df1)
   user_id  q1  q2  q3
0       33   1   0   1
1       33   1   0   1
2       33   1   0   1
3       33   1   0   1
4       22   0   1   1
5       22   0   1   1

EDIT:

A more efficient solution is to return only a boolean mask with the built-in 'any' and then convert it to int:

df1 = df.groupby('user_id').transform('any').astype(int)
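The one-liner above drops user_id, because transform excludes the grouping column. A sketch of re-attaching it with the same join pattern, using the sample frame from the start of the answer; for columns that only ever hold 0 and 1, the group maximum is an equivalent built-in that keeps the integer dtype without astype.

```python
import pandas as pd

df = pd.DataFrame({
    'user_id': [33, 33, 33, 33, 22, 22],
    'q1': [1, 0, 0, 0, 0, 0],
    'q2': [0, 0, 0, 0, 1, 0],
    'q3': [0, 1, 0, 0, 0, 1],
})

# Boolean mask per group, converted to 0/1, then user_id joined back on the index
flags = df.groupby('user_id').transform('any').astype(int)
out = df[['user_id']].join(flags)

# Equivalent for 0/1 data: the largest value in each group is 1 if any value is 1
out_max = df[['user_id']].join(df.groupby('user_id').transform('max'))
print(out)
```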

Timings:

In [170]: %timeit (df.groupby('user_id').transform(lambda x: 1 if x.any() else 0))
1 loop, best of 3: 514 ms per loop

In [171]: %timeit (df.groupby('user_id').transform('any').astype(int))
10 loops, best of 3: 84 ms per loop

Sample data used for the timings:

import numpy as np
import pandas as pd

np.random.seed(123)
N = 1000
df = pd.DataFrame(np.random.choice([0,1], size=(N, 3)),
                   index=np.random.randint(1000, size=N))
df.index.name = 'user_id'
df = df.add_prefix('q').reset_index()
#print (df)
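%timeit only works in IPython/Jupyter. A sketch of the same comparison in a plain Python script, using the standard timeit module on the sample data above; it first checks that both approaches agree, then prints the best of 3 runs for each. No timings are claimed here, since they depend on the machine and pandas version.

```python
import timeit

import numpy as np
import pandas as pd

# Same synthetic sample: 1000 rows, 3 random 0/1 columns, random user ids
np.random.seed(123)
N = 1000
df = pd.DataFrame(np.random.choice([0, 1], size=(N, 3)),
                  index=np.random.randint(1000, size=N))
df.index.name = 'user_id'
df = df.add_prefix('q').reset_index()

def with_lambda():
    return df.groupby('user_id').transform(lambda x: 1 if x.any() else 0)

def with_builtin():
    return df.groupby('user_id').transform('any').astype(int)

# Both approaches must produce the same values before timing them
assert (with_lambda().values == with_builtin().values).all()

for fn in (with_lambda, with_builtin):
    # best of 3 single runs, like "best of 3" in %timeit
    seconds = min(timeit.repeat(fn, number=1, repeat=3))
    print(f'{fn.__name__}: {seconds * 1000:.1f} ms')
```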

Upvotes: 1
