Mohammad

Reputation: 1013

How to remove duplicate records from PySpark DataFrame based on a condition?

Assume that I have a PySpark DataFrame like below:

# Prepare Data
data = [('Italy', 'ITA'), \
    ('China', 'CHN'), \
    ('China', None), \
    ('France', 'FRA'), \
    ('Spain', None), \
    ('Taiwan', 'TWN'), \
    ('Taiwan', None)
  ]

# Create DataFrame
columns = ['Name', 'Code']
df = spark.createDataFrame(data = data, schema = columns)
df.show(truncate=False)

+------+----+
|Name  |Code|
+------+----+
|Italy |ITA |
|China |CHN |
|China |null|
|France|FRA |
|Spain |null|
|Taiwan|TWN |
|Taiwan|null|
+------+----+

As you can see, a few countries appear twice (China and Taiwan in the example above). I want to delete the records that satisfy both of the following conditions:

  1. The column 'Name' is repeated more than once,

AND

  2. The column 'Code' is Null.

Note that the column 'Code' can be Null for countries that are not repeated, like Spain. I want to keep those records.

The expected output would look like this:

Name     Code
'Italy'  'ITA'
'China'  'CHN'
'France' 'FRA'
'Spain'  Null
'Taiwan' 'TWN'

In fact, I want to have one record for every country. Any idea how to do that?

Upvotes: 0

Views: 1493

Answers (4)

Yuva

Reputation: 3173

You can use Window.partitionBy to achieve your desired result:

from pyspark.sql import Window
import pyspark.sql.functions as f

# max() ignores nulls, so each Name gets its non-null Code where one exists
df1 = df.select('Name', f.max('Code').over(Window.partitionBy('Name')).alias('Code')).distinct()
df1.show()

Output:

+------+----+
|  Name|Code|
+------+----+
| China| CHN|
| Spain|null|
|France| FRA|
|Taiwan| TWN|
| Italy| ITA|
+------+----+

Upvotes: 1

过过招

Reputation: 4244

To get the non-null row first, use the row_number window function, partitioning by the Name column and ordering by the Code column. Because Spark's order by treats null as the smallest value, descending order is used so that null rows sort last. Then take the first row of each group.

from pyspark.sql import functions as F

df = df.withColumn('rn', F.expr('row_number() over (partition by Name order by Code desc)')).filter('rn = 1').drop('rn')
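
If you prefer the DataFrame API over a SQL expression string, the same idea can be sketched with an explicit window; desc_nulls_last simply spells out the null ordering described above:

from pyspark.sql import Window
from pyspark.sql import functions as F

# Put non-null Codes first within each Name, then keep the first row per group
w = Window.partitionBy('Name').orderBy(F.col('Code').desc_nulls_last())
df1 = df.withColumn('rn', F.row_number().over(w)).filter(F.col('rn') == 1).drop('rn')
df1.show()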

Upvotes: 0

jaketclarke

Reputation: 66

There will almost certainly be a cleverer way to do this, but for the sake of a lesson, what if you:

  1. made a new dataframe with just 'Name'
  2. dropped duplicates on that
  3. dropped records where 'Code' is null from the initial table
  4. did a left join between the new table and the old table to bring 'Code' back

I've added Australia with no country code just so you can see it works for that case as well. The example below uses pandas; a rough PySpark translation of the same steps is sketched after it.


import pandas as pd

data = [('Italy', 'ITA'), \
    ('China', 'CHN'), \
    ('China', None), \
    ('France', 'FRA'), \
    ('Spain', None), \
    ('Taiwan', 'TWN'), \
    ('Taiwan', None), \
    ('Australia', None)
  ]

# Create DataFrame
columns = ['Name', 'Code']
df = pd.DataFrame(data = data, columns = columns)
print(df)

# get unique country names
uq_countries = df['Name'].drop_duplicates().to_frame()
print(uq_countries)

# remove None
non_na_codes = df.dropna()
print(non_na_codes)

# combine
final = pd.merge(left=uq_countries, right=non_na_codes, on='Name', how='left')
print(final)
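
Since the question is about PySpark, a rough translation of the same four steps to the DataFrame API might look like this (a sketch, assuming the df from the question):

from pyspark.sql import functions as F

# 1. & 2. unique country names
uq_countries = df.select('Name').distinct()

# 3. rows that actually have a Code
non_na_codes = df.filter(F.col('Code').isNotNull())

# 4. left join so countries without any Code keep a null
final = uq_countries.join(non_na_codes, on='Name', how='left')
final.show()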

Upvotes: 0

Debayan

Reputation: 879

Here is one approach:

# Note: unlike pandas, PySpark's dropDuplicates() has no `keep` argument,
# and which of the duplicate rows is retained is not guaranteed.
df = df.dropDuplicates(subset=["Name"])
Upvotes: 0
