Jake Niederer

Reputation: 91

Grouping Pandas Dataframe by Elements in Column of Lists

I am trying to compute the aggregate sum of the columns in a pandas dataframe, grouped by the individual elements of a column that holds lists. Here is a dummy dataset that represents the data I am working with:

import pandas as pd

preg_df = pd.DataFrame({'Diag_Codes': [['O1414', 'O4103X0', 'O365930', 'O76'], 
                                       ['O200', 'N3000', 'M545', 'R102', 'R110', 'Z3A01'],
                                       ['O365922', 'O30032', 'O09512', 'Z3A26'], 
                                       ['O2341', 'O200', 'Z3A01'], 
                                       ['O209', 'Z3A01']], 
                        'First_Trimester': [0, 1, 0, 1, 1], 
                        'Second_Trimester': [0, 0, 1, 0, 0], 
                        'Third_Trimester': [1, 0, 0, 0, 0]})

I would like to create a new dataframe from this data that is grouped by the diagnosis codes contained in the 'Diag_Codes' column of preg_df. I have been able to accomplish this with the following for loop:

# Create a list of unique diagnosis codes from the preg_df dataframe
diagnoses = list(set([item for sublist in preg_df.Diag_Codes.tolist() for item in sublist]))

diag_dfs = []

for i in diagnoses:

    # Get the indices at which the diagnosis code exists within the 'Diag_Codes' column
    diag_indices = [index for index, row in preg_df.iterrows() if i in row['Diag_Codes']]

    # Subset the dataframe to obtain only records in which the diagnosis code exists within 'Diag_Codes' column
    # (.copy() avoids a SettingWithCopyWarning when the new columns are added below)
    diag_df = preg_df.loc[diag_indices, 'First_Trimester':].copy()
    diag_df['Diag_Code'] = i
    diag_df['Total_Cases'] = len(diag_indices)

    # Group by diagnosis code and the total number of cases and get the aggregate sum of all other columns
    diag_df = diag_df.groupby(['Diag_Code', 'Total_Cases']).sum()
    diag_dfs.append(diag_df)

diag_data = pd.concat(diag_dfs).sort_values(by=['Total_Cases'], ascending=False)
diag_data.head()

The for loop above does produce the dataframe I want from the original dataset, but it does not scale to a large dataset: it makes a full iterrows() pass over the dataframe for every unique diagnosis code. The actual dataframe I am working with has approximately 5 million rows and tens of thousands of unique diagnosis codes, so obtaining the desired grouped dataframe with this loop is not feasible. Is there a more efficient way to produce this output for a much larger dataset?
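A rough way to see the scaling problem, as a sketch under stated assumptions: the loop does one full iterrows() pass per unique code, so its work grows with rows times unique codes (on the order of 5 x 10^10 membership checks for 5 million rows and 10,000 codes), whereas DataFrame.explode followed by a single groupby visits each (row, code) pair once. The generator make_synthetic, its sizes, and the helper names loop_version and explode_version are hypothetical; loop_version is a condensed form of the loop above that sums the trimester columns directly instead of building one grouped frame per code.

```python
import time

import numpy as np
import pandas as pd
from pandas.testing import assert_frame_equal

TRIMESTERS = ['First_Trimester', 'Second_Trimester', 'Third_Trimester']


def make_synthetic(n_rows, n_codes, seed=0):
    """Hypothetical test data: 1-5 distinct random codes per row and one trimester flag."""
    rng = np.random.default_rng(seed)
    pool = [f"C{k:05d}" for k in range(n_codes)]
    codes = [[str(c) for c in rng.choice(pool, size=rng.integers(1, 6), replace=False)]
             for _ in range(n_rows)]
    trimester = rng.integers(0, 3, size=n_rows)
    data = {'Diag_Codes': codes}
    for k, name in enumerate(TRIMESTERS):
        data[name] = (trimester == k).astype(int)
    return pd.DataFrame(data)


def loop_version(df):
    """Condensed form of the question's loop: one full iterrows() pass per unique code."""
    parts = []
    for code in {c for codes in df['Diag_Codes'] for c in codes}:
        rows = [idx for idx, row in df.iterrows() if code in row['Diag_Codes']]
        totals = df.loc[rows, TRIMESTERS].sum()
        totals['Total_Cases'] = len(rows)
        parts.append(totals.rename(code))
    return pd.DataFrame(parts).sort_index()


def explode_version(df):
    """One row per (row, code) pair, then a single groupby over all codes."""
    grouped = df.explode('Diag_Codes').groupby('Diag_Codes')
    out = grouped[TRIMESTERS].sum()
    out['Total_Cases'] = grouped.size()
    return out


df = make_synthetic(n_rows=300, n_codes=50)

start = time.perf_counter()
slow = loop_version(df)
loop_seconds = time.perf_counter() - start

start = time.perf_counter()
fast = explode_version(df)
explode_seconds = time.perf_counter() - start

# Same numbers either way; index names and dtypes are not compared
assert_frame_equal(slow, fast, check_names=False, check_index_type=False, check_dtype=False)
print(f"loop: {loop_seconds:.3f}s  explode: {explode_seconds:.3f}s")
```

Timings depend on the machine and pandas version, so none are quoted here; the point is that increasing n_codes grows the loop's runtime while the explode version stays a single pass.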

Upvotes: 2

Views: 44

Answers (2)

Scott Boston

Reputation: 153460

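Let's try DataFrame.explode (available since pandas 0.25), which gives each list element its own row, and then group by the code: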

preg_df.explode('Diag_Codes').groupby('Diag_Codes').sum()

Output:

            First_Trimester  Second_Trimester  Third_Trimester
Diag_Codes                                                    
M545                      1                 0                0
N3000                     1                 0                0
O09512                    0                 1                0
O1414                     0                 0                1
O200                      2                 0                0
O209                      1                 0                0
O2341                     1                 0                0
O30032                    0                 1                0
O365922                   0                 1                0
O365930                   0                 0                1
O4103X0                   0                 0                1
O76                       0                 0                1
R102                      1                 0                0
R110                      1                 0                0
Z3A01                     3                 0                0
Z3A26                     0                 1                0
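The loop in the question also produces a Total_Cases count and a (Diag_Code, Total_Cases) index. A minimal sketch extending the explode approach to that layout, assuming each original row should count at most once per code (which is what the loop's `i in ...` membership test does); the temporary 'index' column created by reset_index() is used only to drop a code that repeats inside one list.

```python
import pandas as pd

preg_df = pd.DataFrame({'Diag_Codes': [['O1414', 'O4103X0', 'O365930', 'O76'],
                                       ['O200', 'N3000', 'M545', 'R102', 'R110', 'Z3A01'],
                                       ['O365922', 'O30032', 'O09512', 'Z3A26'],
                                       ['O2341', 'O200', 'Z3A01'],
                                       ['O209', 'Z3A01']],
                        'First_Trimester': [0, 1, 0, 1, 1],
                        'Second_Trimester': [0, 0, 1, 0, 0],
                        'Third_Trimester': [1, 0, 0, 0, 0]})

# One row per (original row, code) pair; reset_index() keeps the original row
# label in an 'index' column so a code repeated inside one list is counted once
exploded = (
    preg_df.explode('Diag_Codes')
    .reset_index()
    .drop_duplicates(subset=['index', 'Diag_Codes'])
    .drop(columns='index')
)

grouped = exploded.groupby('Diag_Codes')
diag_data = grouped.sum()
# Number of original rows that mention each code (the loop's Total_Cases)
diag_data['Total_Cases'] = grouped.size()

# Same layout as the loop: (Diag_Code, Total_Cases) index, most frequent codes first
diag_data = (
    diag_data.rename_axis('Diag_Code')
    .set_index('Total_Cases', append=True)
    .sort_values(by='Total_Cases', ascending=False)
)
print(diag_data.head())
```

Because everything after explode is a single groupby, the cost no longer depends on how many unique codes there are.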

Upvotes: 3

Vioxini

Reputation: 102

This should work. It stacks every code from each Diag_Codes list into its own row, keyed by the index of the row it came from, which makes the codes much easier to operate on.


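import pandas as pd

# Spread each list across columns, then stack the codes back into one long Series
# indexed by the original row label (the None padding of shorter lists never becomes a group)
diag_codes = pd.DataFrame(preg_df["Diag_Codes"].tolist()).stack()
diag_codes.index = diag_codes.index.droplevel(-1)
diag_codes.name = "diag_codes"
# Drop the list column so that sum() only aggregates the numeric columns
grouped_codes = preg_df.drop(columns="Diag_Codes").join(diag_codes).groupby('diag_codes').sum()
grouped_codes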

Output:

            First_Trimester  Second_Trimester  Third_Trimester
diag_codes
M545                      1                 0                0
N3000                     1                 0                0
O09512                    0                 1                0
O1414                     0                 0                1
O200                      2                 0                0
O209                      1                 0                0
O2341                     1                 0                0
O30032                    0                 1                0
O365922                   0                 1                0
O365930                   0                 0                1
O4103X0                   0                 0                1
O76                       0                 0                1
R102                      1                 0                0
R110                      1                 0                0
Z3A01                     3                 0                0
Z3A26                     0                 1                0
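On the sample data both answers give the same table. A quick check with pandas.testing.assert_frame_equal, as a sketch assuming a recent pandas: the list column is dropped before the join so that groupby().sum() only sees the numeric columns, and the index is renamed because the two answers name it differently (Diag_Codes vs diag_codes).

```python
import pandas as pd
from pandas.testing import assert_frame_equal

preg_df = pd.DataFrame({'Diag_Codes': [['O1414', 'O4103X0', 'O365930', 'O76'],
                                       ['O200', 'N3000', 'M545', 'R102', 'R110', 'Z3A01'],
                                       ['O365922', 'O30032', 'O09512', 'Z3A26'],
                                       ['O2341', 'O200', 'Z3A01'],
                                       ['O209', 'Z3A01']],
                        'First_Trimester': [0, 1, 0, 1, 1],
                        'Second_Trimester': [0, 0, 1, 0, 0],
                        'Third_Trimester': [1, 0, 0, 0, 0]})

# First answer: explode each list into its own row, then group
via_explode = preg_df.explode('Diag_Codes').groupby('Diag_Codes').sum()

# Second answer: spread the lists into columns, stack them into one long Series
# indexed by the original row label, and join it onto the numeric columns
diag_codes = pd.DataFrame(preg_df['Diag_Codes'].tolist()).stack()
diag_codes.index = diag_codes.index.droplevel(-1)
diag_codes.name = 'diag_codes'
via_stack = (
    preg_df.drop(columns='Diag_Codes')
    .join(diag_codes)
    .groupby('diag_codes')
    .sum()
    .rename_axis('Diag_Codes')
)

# Raises AssertionError if labels or values differ; the index dtype is not compared
assert_frame_equal(via_explode, via_stack, check_index_type=False)
print("both approaches agree")
```

Of the two, explode avoids building the intermediate wide dataframe, which is padded out to the length of the longest list.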

Let me know if anything needs fixing, or go ahead and use this. If you tell me the dimensions of the dataframe, I can check whether this is well optimized. Remember: always try to use built-in (vectorized) functions, and use "for" loops only as a last resort.

Upvotes: 1
