kev

Reputation: 2881

Pandas: Conditional groupby and max

I have the following dataframe:

my_df:

user_id |  spend |  transaction_id |
--------+--------+-----------------|
1       |   45   |        12       |
2       |   33   |        45       |
3       |   12   |        33       |
1       |   22   |        56       |
1       |   77   |        99       |

To de-duplicate the table above, I want to keep the greatest transaction_id for each user_id. So I did the following:

deduped_df = my_df.groupby("user_id")[["transaction_id"]].max().reset_index()

which gave me the following result:

user_id |  transaction_id |
--------+-----------------|
1       |   99            |
2       |   45            |
3       |   33            |
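A minimal self-contained sketch that reproduces the steps above. The DataFrame construction is an assumption (the question does not show how my_df was loaded); the values are copied from the table.

```python
import pandas as pd

# Sample data copied from the question's table (construction code is assumed)
my_df = pd.DataFrame({
    "user_id": [1, 2, 3, 1, 1],
    "spend": [45, 33, 12, 22, 77],
    "transaction_id": [12, 45, 33, 56, 99],
})

# Largest transaction_id per user_id; reset_index() turns the
# group keys back into a regular user_id column
deduped_df = my_df.groupby("user_id")[["transaction_id"]].max().reset_index()
print(deduped_df)
```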

Now, what if I want to apply the above operation to only the first 100 user_ids and discard the other rows? Also, I lost my spend column. How do I retain that column in the above operation?

I want my final result to look like this:

user_id |  spend |  transaction_id |
--------+--------+-----------------|
1       |   77   |        99       |
2       |   33   |        45       |
3       |   12   |        33       |

Upvotes: 1

Views: 469

Answers (1)

jezrael

Reputation: 862921

You can keep the first 100 user_id groups by calling head(100) after the aggregation. The result is already ordered by user_id, because groupby sorts the group keys by default:

deduped_df = my_df.groupby("user_id")[["transaction_id"]].max().head(100).reset_index()
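A small sketch of the same idea on the question's sample data. n_users is a hypothetical name, set to 2 instead of 100 so that the truncation is visible with only three users.

```python
import pandas as pd

my_df = pd.DataFrame({
    "user_id": [1, 2, 3, 1, 1],
    "spend": [45, 33, 12, 22, 77],
    "transaction_id": [12, 45, 33, 56, 99],
})

n_users = 2  # hypothetical limit standing in for 100

# groupby sorts keys ascending by default, so head() keeps the
# n_users smallest user_id values
deduped_df = (my_df.groupby("user_id")[["transaction_id"]]
                   .max()
                   .head(n_users)
                   .reset_index())
print(deduped_df)
```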

If you need to sort manually, sort the DataFrame first and pass sort=False to groupby:

my_df = my_df.sort_values('user_id')
deduped_df = (my_df.groupby("user_id", sort=False)[["transaction_id"]]
                   .max()
                   .head(100)
                   .reset_index())
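To see what the manual sort does, here is a sketch with a hypothetical unsorted frame (unsorted_df and its values are made up for illustration). After sort_values, sort=False follows the sorted row order, so the groups still come out in ascending user_id order.

```python
import pandas as pd

# Hypothetical data with user_id deliberately out of order
unsorted_df = pd.DataFrame({
    "user_id": [3, 1, 2, 1, 3],
    "spend": [12, 45, 33, 77, 5],
    "transaction_id": [33, 12, 45, 99, 7],
})

n_users = 2  # hypothetical limit standing in for 100

# Sort rows first; with sort=False, groups appear in the order of
# their first row, which is now ascending user_id
sorted_df = unsorted_df.sort_values("user_id")
deduped_df = (sorted_df.groupby("user_id", sort=False)[["transaction_id"]]
                       .max()
                       .head(n_users)
                       .reset_index())
print(deduped_df)
```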

If you need the first 100 user_id values in their order of first appearance (without sorting), add the parameter sort=False to groupby:

deduped_df = (my_df.groupby("user_id", sort=False)[["transaction_id"]]
                   .max()
                   .head(100)
                   .reset_index())
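A sketch of the difference, again on hypothetical unsorted data (unsorted_df is made up for illustration): sort=False on its own keeps groups in order of first appearance, so head() picks the first users encountered rather than the smallest IDs.

```python
import pandas as pd

# Hypothetical data with user_id deliberately out of order
unsorted_df = pd.DataFrame({
    "user_id": [3, 1, 2, 1, 3],
    "spend": [12, 45, 33, 77, 5],
    "transaction_id": [33, 12, 45, 99, 7],
})

n_users = 2  # hypothetical limit standing in for 100

# No sorting: user_id 3 appears first in the data, then 1, then 2
deduped_df = (unsorted_df.groupby("user_id", sort=False)[["transaction_id"]]
                         .max()
                         .head(n_users)
                         .reset_index())
print(deduped_df)
```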

Also, I lost my spend column. How do I retain that column in the above operation?

This is expected: spend is not a grouping key, no aggregate function is applied to it, and [["transaction_id"]] selects only the transaction_id column before max() is called, so spend is dropped from the result.
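A quick check on the question's sample data, as a sketch: the column selection happens before max(), so spend never takes part in the aggregation, and user_id moves into the index.

```python
import pandas as pd

my_df = pd.DataFrame({
    "user_id": [1, 2, 3, 1, 1],
    "spend": [45, 33, 12, 22, 77],
    "transaction_id": [12, 45, 33, 56, 99],
})

# Only transaction_id is selected from the grouped object
result = my_df.groupby("user_id")[["transaction_id"]].max()

print(list(my_df.columns))   # ['user_id', 'spend', 'transaction_id']
print(list(result.columns))  # ['transaction_id']
print(result.index.name)     # user_id
```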

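One possible solution is to add an aggregate function for the spend column, such as sum. Note that spend is then aggregated independently of transaction_id, so for user_id 1 you get 45 + 22 + 77 = 144 rather than the 77 from the row with the largest transaction_id: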

deduped_df = (my_df.groupby("user_id", sort=False)
                   .agg({"transaction_id":'max', 'spend':'sum'})
                   .head(100)
                   .reset_index())
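Running the agg version on the question's sample data makes the trade-off concrete; this sketch only adds the DataFrame construction around the answer's code.

```python
import pandas as pd

my_df = pd.DataFrame({
    "user_id": [1, 2, 3, 1, 1],
    "spend": [45, 33, 12, 22, 77],
    "transaction_id": [12, 45, 33, 56, 99],
})

deduped_df = (my_df.groupby("user_id", sort=False)
                   .agg({"transaction_id": "max", "spend": "sum"})
                   .head(100)
                   .reset_index())
print(deduped_df)
# Each column is aggregated on its own: for user_id 1,
# transaction_id is max(12, 56, 99) and spend is 45 + 22 + 77
```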

EDIT: If you need the complete rows that hold the maximum transaction_id for the first 100 user_id groups, use DataFrameGroupBy.idxmax to get their index labels and select those rows with DataFrame.loc:

deduped_df = my_df.loc[my_df.groupby("user_id")["transaction_id"].idxmax().head(100)]
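A sketch on the question's sample data showing that the idxmax approach reproduces the desired output, spend included. It assumes my_df has a unique index (such as the default RangeIndex); with duplicate labels, loc can return more than one row per label, so call reset_index() first in that case. The final reset_index(drop=True) is only for a clean display.

```python
import pandas as pd

my_df = pd.DataFrame({
    "user_id": [1, 2, 3, 1, 1],
    "spend": [45, 33, 12, 22, 77],
    "transaction_id": [12, 45, 33, 56, 99],
})

# Series indexed by user_id whose values are the row labels of the
# largest transaction_id in each group
idx = my_df.groupby("user_id")["transaction_id"].idxmax().head(100)

# Select those complete rows, so spend comes from the same row
deduped_df = my_df.loc[idx].reset_index(drop=True)
print(deduped_df)
```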

Upvotes: 1
