Reputation: 619
Let's say I have a PySpark data frame, like so:
+--+--+--+--+
|a |b |c |d |
+--+--+--+--+
|1 |0 |1 |2 |
|0 |2 |0 |1 |
|1 |0 |1 |2 |
|0 |4 |3 |1 |
+--+--+--+--+
How can I create a column marking all of the duplicate rows, like so:
+--+--+--+--+--+
|a |b |c |d |e |
+--+--+--+--+--+
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|1 |0 |1 |2 |1 |
|0 |4 |3 |1 |0 |
+--+--+--+--+--+
I attempted it using the groupBy and aggregate functions to no avail.
Upvotes: 7
Views: 27816
Reputation: 965
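I think a pandas_udf can handle this more simply. First, create a pandas UDF that takes a Series and returns True for the duplicated values. Then use withColumn to mark the duplicated rows. Note that Spark calls this kind of UDF on one batch of rows at a time, so duplicates that land in different batches are not detected, and only the column you pass in is compared, not the whole row. Here is my suggested code: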
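import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf('boolean')
def duplicate_finder(s: pd.Series) -> pd.Series:
    # keep=False flags every occurrence of a repeated value, not only the later ones
    return s.duplicated(keep=False)
df.withColumn('Duplicated', duplicate_finder('DESIRED_COLUMN')).show()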
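To illustrate why a Series-to-Series pandas UDF only detects duplicates within the batch it receives, here is a plain-Python sketch (not Spark code). The helper mark_duplicates is hypothetical and only mimics Series.duplicated(keep=False); the two-row batch split is an assumption for illustration, since Spark chooses the real batch boundaries.

```python
# Plain-Python sketch (not Spark code) of batch-local duplicate detection.
from collections import Counter


def mark_duplicates(values):
    # Hypothetical helper mimicking pandas Series.duplicated(keep=False):
    # every occurrence of a value that appears more than once is True.
    counts = Counter(values)
    return [counts[v] > 1 for v in values]


# Column 'a' from the question's example DataFrame.
column_a = [1, 0, 1, 0]

# Seen all at once, every value of 'a' repeats.
whole = mark_duplicates(column_a)

# Assumed split into two batches of two rows; each call only compares
# values inside its own batch, as a pandas UDF call would.
batches = [column_a[:2], column_a[2:]]
per_batch = [flag for batch in batches for flag in mark_duplicates(batch)]

print(whole)      # [True, True, True, True]
print(per_batch)  # [False, False, False, False]
```

The sketch also shows the single-column limitation: on column a alone all four rows are flagged, although only the first and third rows are duplicates of each other as whole rows.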
Upvotes: 2
Reputation: 513
# Keep only the (Item_group, Item_name, price) combinations that occur more than once
df1 = df_interr.groupBy("Item_group", "Item_name", "price").count().filter("count > 1")
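The one-liner above returns only the duplicated groups and their counts, not a flag on every row. Here is a plain-Python sketch (not Spark code) of the same group, count and filter logic, using all four columns of the question's data instead of the Item_group, Item_name and price columns of the answer's df_interr:

```python
# Plain-Python sketch (not Spark code) of groupBy(...).count().filter("count > 1").
from collections import Counter

rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]  # columns a, b, c, d

counts = Counter(rows)  # like groupBy(all columns).count()
duplicated_groups = {row: n for row, n in counts.items() if n > 1}  # like .filter("count > 1")

print(duplicated_groups)  # {(1, 0, 1, 2): 2}
```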
Upvotes: 0
Reputation: 41957
Define a window function that checks whether the count of rows, when grouped by all columns, is greater than 1. If so, the row is a duplicate (1); otherwise it is not (0).
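from pyspark.sql import functions as f
from pyspark.sql import window as w
allColumns = df.columns
# Partition by every column so identical rows share a partition; the frame spans the whole
# partition (sys.maxint no longer exists in Python 3, so use the unbounded frame boundaries)
windowSpec = w.Window.partitionBy(allColumns).rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)
df.withColumn('e', f.when(f.count(f.col('d')).over(windowSpec) > 1, f.lit(1)).otherwise(f.lit(0))).show(truncate=False)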
which should give you
+---+---+---+---+---+
|a |b |c |d |e |
+---+---+---+---+---+
|1 |0 |1 |2 |1 |
|1 |0 |1 |2 |1 |
|0 |2 |0 |1 |0 |
|0 |4 |3 |1 |0 |
+---+---+---+---+---+
I hope the answer is helpful
Updated
As @pault commented, you can eliminate when, col and lit by casting the boolean to an integer:
df.withColumn('e', (f.count('*').over(windowSpec) > 1).cast('int')).show(truncate=False)
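The window approach keeps every row and annotates it with the size of its group of identical rows. Here is a plain-Python sketch of that logic (not Spark code), where a Counter stands in for the partition sizes and int(...) plays the role of .cast('int'):

```python
# Plain-Python sketch (not Spark code) of
# (count('*').over(Window.partitionBy(all columns)) > 1).cast('int').
from collections import Counter

rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]  # columns a, b, c, d

# Size of each "partition" of identical rows.
partition_sizes = Counter(rows)

# Keep every row and append e = 1 if its partition holds more than one row, else 0.
with_e = [row + (int(partition_sizes[row] > 1),) for row in rows]

for r in with_e:
    print(r)
# (1, 0, 1, 2, 1)
# (0, 2, 0, 1, 0)
# (1, 0, 1, 2, 1)
# (0, 4, 3, 1, 0)
```

Unlike Spark, which returns rows grouped by partition (as in the output shown above), this sketch keeps the original row order.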
Upvotes: 8
Reputation: 43494
Just to expand on my comment:
You can group by all of the columns and use pyspark.sql.functions.count() to determine whether a row is duplicated:
import pyspark.sql.functions as f
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")).show()
#+---+---+---+---+---+
#| a| b| c| d| e|
#+---+---+---+---+---+
#| 1| 0| 1| 2| 1|
#| 0| 2| 0| 1| 0|
#| 0| 4| 3| 1| 0|
#+---+---+---+---+---+
Here we use count("*") > 1 as the aggregate function and cast the result to an int. The groupBy() has the side effect of dropping the duplicate rows. Depending on your needs, this may be sufficient.
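A plain-Python sketch (not Spark code) of what this groupBy() aggregation produces: one output row per distinct input row, with e set from that row's count. Spark does not guarantee the row order of a groupBy() result, so only the set of rows is meaningful here.

```python
# Plain-Python sketch (not Spark code) of
# df.groupBy(df.columns).agg((f.count("*") > 1).cast("int").alias("e")).
from collections import Counter

rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]  # columns a, b, c, d

# One output row per distinct input row; e is 1 when that row occurred more than once.
grouped = [row + (int(count > 1),) for row, count in Counter(rows).items()]

for r in grouped:
    print(r)
```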
However, if you'd like to keep all of the rows, you can use a Window function as shown in the other answers, or you can use a join():
df.join(
df.groupBy(df.columns).agg((f.count("*")>1).cast("int").alias("e")),
on=df.columns,
how="inner"
).show()
#+---+---+---+---+---+
#| a| b| c| d| e|
#+---+---+---+---+---+
#| 1| 0| 1| 2| 1|
#| 1| 0| 1| 2| 1|
#| 0| 2| 0| 1| 0|
#| 0| 4| 3| 1| 0|
#+---+---+---+---+---+
Here we inner join the original DataFrame with the result of the groupBy() above, on all of the columns.
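A plain-Python sketch (not Spark code) of the join: the grouped result becomes a lookup table keyed by the full row, and each original row picks up its flag from it, so no rows are lost.

```python
# Plain-Python sketch (not Spark code) of joining the groupBy() result back on all columns.
from collections import Counter

rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]  # columns a, b, c, d

# Right-hand side of the join: the groupBy() result as a dict keyed by the full row.
flags = {row: int(count > 1) for row, count in Counter(rows).items()}

# Inner join on all columns: keep each original row that has a match and append its e.
joined = [row + (flags[row],) for row in rows if row in flags]

for r in joined:
    print(r)
```

One difference to keep in mind: Spark's equality join does not match null to null, so a row with a null in any column would be dropped by the inner join, whereas this dict lookup treats None as an ordinary key.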
Upvotes: 14
Reputation: 3367
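Partition your dataframe by all the columns and then apply dense_rank.
from pyspark.sql.functions import dense_rank
from pyspark.sql import window as w
df.withColumn('e', dense_rank().over(w.Window.partitionBy(df.columns))).show()
Note that Spark rejects dense_rank() over a window that has no ORDER BY, and even with .orderBy(df.columns) every row in a partition of identical rows ties at rank 1, so the rank alone does not separate duplicates from unique rows. Comparing the row count of each partition with 1, as in the window answer above, does flag them.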
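To see why a rank computed inside all-column partitions cannot flag duplicates, here is a plain-Python sketch (not Spark code) with a hypothetical dense_ranks helper. Within a partition of identical rows, ordering by the row's own columns ties every row, so each gets dense rank 1, while the partition size does separate duplicates from unique rows.

```python
# Plain-Python sketch (not Spark code): compare a dense rank inside each
# all-column partition with the size of that partition.
from collections import defaultdict

rows = [(1, 0, 1, 2), (0, 2, 0, 1), (1, 0, 1, 2), (0, 4, 3, 1)]  # columns a, b, c, d


def dense_ranks(values):
    # Hypothetical helper: dense rank = 1 + number of distinct smaller values.
    distinct = sorted(set(values))
    return [distinct.index(v) + 1 for v in values]


# Group row positions by the full row, like Window.partitionBy(df.columns).
partitions = defaultdict(list)
for i, row in enumerate(rows):
    partitions[row].append(i)

ranks = [0] * len(rows)
flags = [0] * len(rows)
for row, positions in partitions.items():
    # Ordering a partition by its own (identical) columns ties every row.
    for pos, r in zip(positions, dense_ranks([rows[p] for p in positions])):
        ranks[pos] = r
        flags[pos] = int(len(positions) > 1)

print(ranks)  # [1, 1, 1, 1]
print(flags)  # [1, 0, 1, 0]
```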
Upvotes: 1