Reputation: 11
I have a dataframe like the following:
id | first_name | last_name | sex | country |
---|---|---|---|---|
01 | John | Doe | Male | USA |
02 | John | Doe | Male | Canada |
03 | John | Doe | Male | Mexico |
04 | Mark | Kay | Male | Italy |
05 | John | Doe | Male | Spain |
06 | Mark | Kay | Male | France |
07 | John | Doe | Male | Peru |
08 | Mark | Kay | Male | India |
09 | Mark | Kay | Male | Laos |
10 | John | Doe | Male | Benin |
As you can see, the id and country columns are always unique, but the dataframe has duplicates based on the columns first_name, last_name and sex. I want to find such duplicates, keep only up to 3 of each (ideally the last 3), and drop the rest. After this operation, the resulting dataframe should look like this:
id | first_name | last_name | sex | country |
---|---|---|---|---|
05 | John | Doe | Male | Spain |
06 | Mark | Kay | Male | France |
07 | John | Doe | Male | Peru |
08 | Mark | Kay | Male | India |
09 | Mark | Kay | Male | Laos |
10 | John | Doe | Male | Benin |
How can I do this? Any help is appreciated!
I tried this:
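from pyspark.sql import Window
from pyspark.sql import functions as F

# Number the rows within each (first_name, last_name, sex) group, highest id first
window_spec = Window.partitionBy('first_name', 'last_name', 'sex').orderBy(F.desc('id'))
df_with_row_number = df.withColumn('row_number', F.row_number().over(window_spec))
# Keep the 3 highest ids per group, then drop the helper column
filtered_df = df_with_row_number.filter('row_number <= 3')
result_df = filtered_df.drop('row_number')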
This does give me the result I want. But I'm wondering if there's a more efficient way to achieve this, since it's a large dataset with a lot more rows and columns.
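To make explicit what the window spec above computes, here is a minimal plain-Python sketch (no Spark involved) of the same partition / order-by-descending-id / row_number <= 3 logic on the sample rows. The helper name keep_last_n is made up for illustration, and this only shows the logic; it says nothing about Spark performance.

```python
from collections import defaultdict

# Sample rows from the question: (id, first_name, last_name, sex, country)
rows = [
    (1, "John", "Doe", "Male", "USA"),
    (2, "John", "Doe", "Male", "Canada"),
    (3, "John", "Doe", "Male", "Mexico"),
    (4, "Mark", "Kay", "Male", "Italy"),
    (5, "John", "Doe", "Male", "Spain"),
    (6, "Mark", "Kay", "Male", "France"),
    (7, "John", "Doe", "Male", "Peru"),
    (8, "Mark", "Kay", "Male", "India"),
    (9, "Mark", "Kay", "Male", "Laos"),
    (10, "John", "Doe", "Male", "Benin"),
]


def keep_last_n(rows, n=3):
    """Hypothetical helper: emulate partitionBy(keys) + orderBy(desc(id)) + row_number() <= n."""
    seen = defaultdict(int)  # rows seen so far per (first_name, last_name, sex)
    kept = []
    # Walk from the highest id to the lowest, like orderBy(F.desc('id'))
    for row in sorted(rows, key=lambda r: r[0], reverse=True):
        key = row[1:4]
        seen[key] += 1  # this count is the row's row_number within its group
        if seen[key] <= n:
            kept.append(row)
    # Return in ascending id order for display
    return sorted(kept, key=lambda r: r[0])


result = keep_last_n(rows)
print([r[0] for r in result])  # [5, 6, 7, 8, 9, 10]
```

The kept ids match the expected table in the question: 5, 7, 10 for John Doe and 6, 8, 9 for Mark Kay.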
Upvotes: 1
Views: 44
Reputation: 866
You can use the groupby method on a pandas dataframe. Here is an example:
import pandas as pd
# Your original dataframe
data = {
'id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'first_name': ['John', 'John', 'John', 'Mark', 'John', 'Mark', 'John', 'Mark', 'Mark', 'John'],
'last_name': ['Doe', 'Doe', 'Doe', 'Kay', 'Doe', 'Kay', 'Doe', 'Kay', 'Kay', 'Doe'],
'sex': ['Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male'],
'country': ['USA', 'Canada', 'Mexico', 'Italy', 'Spain', 'France', 'Peru', 'India', 'Laos', 'Benin']
}
df = pd.DataFrame(data)
# Sort the dataframe based on the 'id' column
df = df.sort_values(by='id')
# Keep only the last 3 occurrences for each group of first_name, last_name, and sex
result_df = df.groupby(['first_name', 'last_name', 'sex']).tail(3)
# Reset the index if needed
result_df = result_df.reset_index(drop=True)
# Display the resulting dataframe
print(result_df)
This should output:
   id first_name last_name   sex country
0   5       John       Doe  Male   Spain
1   6       Mark       Kay  Male  France
2   7       John       Doe  Male    Peru
3   8       Mark       Kay  Male   India
4   9       Mark       Kay  Male    Laos
5  10       John       Doe  Male   Benin
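As a possible variation on the pandas approach above, the sketch below ranks rows within each group by id instead of relying on sort order plus tail(3). It is closer in spirit to the row_number window in the question. It assumes id is unique and orderable, as in the sample data; the names keys and rank_in_group are illustrative only.

```python
import pandas as pd

# Same sample data as in the question
df = pd.DataFrame({
    "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "first_name": ["John", "John", "John", "Mark", "John", "Mark", "John", "Mark", "Mark", "John"],
    "last_name": ["Doe", "Doe", "Doe", "Kay", "Doe", "Kay", "Doe", "Kay", "Kay", "Doe"],
    "sex": ["Male"] * 10,
    "country": ["USA", "Canada", "Mexico", "Italy", "Spain", "France", "Peru", "India", "Laos", "Benin"],
})

keys = ["first_name", "last_name", "sex"]

# Rank rows inside each group by id, highest first (1 = largest id in the group)
rank_in_group = df.groupby(keys)["id"].rank(method="first", ascending=False)

# Keep at most the 3 highest ids per group; the original row order is preserved
result_df = df[rank_in_group <= 3].reset_index(drop=True)

print(result_df["id"].tolist())  # [5, 6, 7, 8, 9, 10]
```

Because the filter is a boolean mask on the original frame, no prior sort_values call is needed, and groups with 3 or fewer rows are left untouched.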
Upvotes: 1