Kruttika Swaminathan

Reputation: 11

Filter out n number of duplicates

I have a dataframe like the following:

id first_name last_name sex country
01 John Doe Male USA
02 John Doe Male Canada
03 John Doe Male Mexico
04 Mark Kay Male Italy
05 John Doe Male Spain
06 Mark Kay Male France
07 John Doe Male Peru
08 Mark Kay Male India
09 Mark Kay Male Laos
10 John Doe Male Benin

As you can see, the id and country columns are always unique, but the dataframe has duplicates based on the columns first_name, last_name and sex. I want to find such duplicates and keep only up to 3 of them (ideally the last 3), dropping the rest. After this operation, the resulting dataframe should look like this:

id first_name last_name sex country
05 John Doe Male Spain
06 Mark Kay Male France
07 John Doe Male Peru
08 Mark Kay Male India
09 Mark Kay Male Laos
10 John Doe Male Benin

How can I do this? Any help is appreciated!

I tried this:

from pyspark.sql import Window
from pyspark.sql import functions as F

window_spec = Window.partitionBy('first_name', 'last_name', 'sex').orderBy(F.desc('id'))
df_with_row_number = df.withColumn('row_number', F.row_number().over(window_spec))
filtered_df = df_with_row_number.filter('row_number <= 3')
result_df = filtered_df.drop('row_number')

This does give me the result I want, but I'm wondering if there's a more efficient way to achieve it, since the real dataset is large, with many more rows and columns.

Upvotes: 1

Views: 44

Answers (1)

Reda Bourial

Reputation: 866

You can use the groupby method on a pandas dataframe. Here is an example:

import pandas as pd

# Your original dataframe
data = {
    'id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'first_name': ['John', 'John', 'John', 'Mark', 'John', 'Mark', 'John', 'Mark', 'Mark', 'John'],
    'last_name': ['Doe', 'Doe', 'Doe', 'Kay', 'Doe', 'Kay', 'Doe', 'Kay', 'Kay', 'Doe'],
    'sex': ['Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male', 'Male'],
    'country': ['USA', 'Canada', 'Mexico', 'Italy', 'Spain', 'France', 'Peru', 'India', 'Laos', 'Benin']
}

df = pd.DataFrame(data)

# Sort the dataframe based on the 'id' column
df = df.sort_values(by='id')

# Keep only the last 3 occurrences for each group of first_name, last_name, and sex
result_df = df.groupby(['first_name', 'last_name', 'sex']).tail(3)

# Reset the index if needed
result_df = result_df.reset_index(drop=True)

# Display the resulting dataframe
print(result_df)

This should output:

   id first_name last_name   sex country
0   5       John       Doe  Male   Spain
1   6       Mark       Kay  Male  France
2   7       John       Doe  Male    Peru
3   8       Mark       Kay  Male   India
4   9       Mark       Kay  Male    Laos
5  10       John       Doe  Male   Benin
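
Since the title asks about n duplicates rather than exactly 3, the same idea can be wrapped in a small helper. This is only a sketch: keep_last_n and its parameters are hypothetical names, not part of pandas. It ranks rows from the end of each group with cumcount(ascending=False), which is roughly the pandas counterpart of the row_number window in the question.

```python
import pandas as pd


def keep_last_n(df, keys, n=3, order_by="id"):
    """Hypothetical helper: keep at most the last `n` rows per group of `keys`."""
    ordered = df.sort_values(by=order_by)
    # cumcount(ascending=False) counts rows from the end of each group,
    # so the last row of every group gets 0, the one before it 1, and so on.
    rank_from_end = ordered.groupby(keys).cumcount(ascending=False)
    return ordered[rank_from_end < n].reset_index(drop=True)


# Same sample data as in the question.
df = pd.DataFrame({
    "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "first_name": ["John", "John", "John", "Mark", "John",
                   "Mark", "John", "Mark", "Mark", "John"],
    "last_name": ["Doe", "Doe", "Doe", "Kay", "Doe",
                  "Kay", "Doe", "Kay", "Kay", "Doe"],
    "sex": ["Male"] * 10,
    "country": ["USA", "Canada", "Mexico", "Italy", "Spain",
                "France", "Peru", "India", "Laos", "Benin"],
})

result = keep_last_n(df, ["first_name", "last_name", "sex"], n=3)
print(result["id"].tolist())  # [5, 6, 7, 8, 9, 10]
```

Changing n to 1 would keep only the most recent row per person (ids 9 and 10 here). For this sample the result matches the tail(3) version above.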

Upvotes: 1
