Chris C

Reputation: 619

PySpark - Get indices of duplicate rows

Let's say I have a PySpark data frame, like so:

+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
+--+--+--+--+

How can I create a column marking all of the duplicate rows, like so:

+--+--+--+--+--+
|a |b |c |d |e |
+--+--+--+--+--+
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|1 |0 |1 |2 |1 |
|0 |4 |3 |1 |0 |
+--+--+--+--+--+

I attempted it using the groupBy and aggregate functions to no avail.
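
For reference, the frame above can be recreated with something like this (a minimal sketch using createDataFrame; assumes a local SparkSession):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# the same four rows shown above
df = spark.createDataFrame(
    [(1, 0, 1, 2),
     (0, 2, 0, 1),
     (1, 0, 1, 2),
     (0, 4, 3, 1)],
    ["a", "b", "c", "d"],
)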

Upvotes: 7

Views: 27816

Answers (5)

I think pandas_udf can handle this more easily. First, create a pandas UDF that takes a Series and returns True for the duplicated rows. Then use withColumn to mark them. Here is my suggested code:

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf('boolean')
def duplicate_finder(s: pd.Series) -> pd.Series:
    # keep=False flags every occurrence of a duplicated value, not just the later ones
    return s.duplicated(keep=False)

df.withColumn('Duplicated', duplicate_finder('DESIRED_COLUMN')).show()

Upvotes: 2

Shyam Gupta

Reputation: 513

# keep only the column combinations that occur more than once
df1 = df_interr.groupBy("Item_group", "Item_name", "price").count().filter("count > 1")

Upvotes: 0

Ramesh Maharjan

Reputation: 41957

Define a window spec partitioned by all of the columns and check whether the count of rows over that window is greater than 1. If it is, the row is a duplicate (1); otherwise it is not (0).

allColumns = df.columns
from pyspark.sql import functions as f
from pyspark.sql import window as w

# span the whole partition so every row sees the full group count
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)

df.withColumn('e', f.when(f.count(f.col('d')).over(windowSpec) > 1, f.lit(1)).otherwise(f.lit(0))).show(truncate=False)

which should give you

+---+---+---+---+---+
|a  |b  |c  |d  |e  |
+---+---+---+---+---+
|1  |0  |1  |2  |1  |
|1  |0  |1  |2  |1  |
|0  |2  |0  |1  |0  |
|0  |4  |3  |1  |0  |
+---+---+---+---+---+

I hope the answer is helpful.

Updated

As @pault commented, you can eliminate when, col and lit by casting the boolean to integer:

df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)

Upvotes: 8

pault

Reputation: 43494

Just to expand on my comment:

You can group by all of the columns and use pyspark.sql.functions.count() to determine if a row is duplicated:

import pyspark.sql.functions as f
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we use count("*") > 1 as the aggregate function, and cast the result to an int. The groupBy() will have the consequence of dropping the duplicate rows. Depending on your needs, this may be sufficient.

However, if you'd like to keep all of the rows, you can use a Window function like shown in the other answers OR you can use a join():

df.join(
    df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")),
    on=df.columns,
    how="inner"
).show()
#+---+---+---+---+---+
#|  a|  b|  c|  d|  e|
#+---+---+---+---+---+
#|  1|  0|  1|  2|  1|
#|  1|  0|  1|  2|  1|
#|  0|  2|  0|  1|  0|
#|  0|  4|  3|  1|  0|
#+---+---+---+---+---+

Here we inner join the original dataframe with the one that is the result of the groupBy() above on all of the columns.

Upvotes: 14

Kaushal

Reputation: 3367

Partition your dataframe by all of the columns and then apply dense_rank.

from pyspark.sql.functions import dense_rank
from pyspark.sql import window as w

# ranking functions require an ordered window, so order within each partition as well
df.withColumn('e', dense_rank().over(w.Window.partitionBy(df.columns).orderBy(df.columns))).show()

Upvotes: 1
