Behzad Rowshanravan

Reputation: 247

Labelling duplicates in PySpark

I am trying to label the duplicate rows in my PySpark DataFrame with a number per group, while keeping the full-length DataFrame. Below is example code.

data = [
    ("A", "2018-01-03"),
    ("A", "2018-01-03"),
    ("A", "2018-01-03"),
    ("B", "2019-01-03"),
    ("B", "2019-01-03"),
    ("B", "2019-01-03"),
    ("C", "2020-01-03"),
    ("C", "2020-01-03"),
    ("C", "2020-01-03"),
]

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(data=data, schema=["Group", "Date"])
df = df.withColumn("Date", F.to_date("Date", "yyyy-MM-dd"))


from pyspark.sql import Window
windowSpec = Window.partitionBy("Group").orderBy(F.asc("Date"))

df.withColumn("group_number", F.dense_rank().over(windowSpec)).orderBy("Date").show()

This is my current output. It is correct in the sense that the code ranks "Date" within each group, but it is not my desired outcome.

+-----+----------+------------+
|Group|      Date|group_number|
+-----+----------+------------+
|    A|2018-01-03|           1|
|    A|2018-01-03|           1|
|    A|2018-01-03|           1|
|    B|2019-01-03|           1|
|    B|2019-01-03|           1|
|    B|2019-01-03|           1|
|    C|2020-01-03|           1|
|    C|2020-01-03|           1|
|    C|2020-01-03|           1|
+-----+----------+------------+

I was hoping my output would look like this:

+-----+----------+------------+
|Group|      Date|group_number|
+-----+----------+------------+
|    A|2018-01-03|           1|
|    A|2018-01-03|           1|
|    A|2018-01-03|           1|
|    B|2019-01-03|           2|
|    B|2019-01-03|           2|
|    B|2019-01-03|           2|
|    C|2020-01-03|           3|
|    C|2020-01-03|           3|
|    C|2020-01-03|           3|
+-----+----------+------------+

Any suggestions? I have found this post, but it is only a binary solution, and I have more than 2 groups in my dataset.

Upvotes: 1

Views: 329

Answers (2)

blackbishop

Reputation: 32650

What you want is to rank over all the groups, not within each group, so you don't need to partition the window; ordering by Group and Date will give you the desired output:

windowSpec = Window.orderBy(F.asc("Group"), F.asc("Date"))

df.withColumn("group_number", F.dense_rank().over(windowSpec)).orderBy("Date").show()

#+-----+----------+------------+
#|Group|      Date|group_number|
#+-----+----------+------------+
#|    A|2018-01-03|           1|
#|    A|2018-01-03|           1|
#|    A|2018-01-03|           1|
#|    B|2019-01-03|           2|
#|    B|2019-01-03|           2|
#|    B|2019-01-03|           2|
#|    C|2020-01-03|           3|
#|    C|2020-01-03|           3|
#|    C|2020-01-03|           3|
#+-----+----------+------------+

And you certainly don't need a UDF like the one the other answer suggests.
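One caveat worth noting: a window without partitionBy moves all rows into a single partition (Spark logs a WindowExec warning about it). That's fine on a frame this small; if the full DataFrame were large, one sketch of a workaround, assuming the number of distinct groups is small, is to rank only the distinct groups and join the result back:

# Rank only the distinct Group values (a tiny DataFrame), then join back.
groups = df.select("Group").distinct().withColumn(
    "group_number", F.dense_rank().over(Window.orderBy("Group"))
)

df.join(groups, on="Group").orderBy("Date").show()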

Upvotes: 1

Willy Chang

Reputation: 51

You don't need the partitionBy function when you declare your windowSpec. By specifying the column "Group" in partitionBy, you're telling Spark to compute a dense_rank() within each partition based on "Date". So the output is correct: the rows in group A all share the same date, so they all get a group_number of 1, and the same holds for groups B and C.
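To see why, here is a small hypothetical example (demo and w are illustration names, not from the question) where group A has two distinct dates:

demo = spark.createDataFrame(
    [("A", "2018-01-03"), ("A", "2018-02-03"), ("B", "2019-01-03")],
    ["Group", "Date"],
).withColumn("Date", F.to_date("Date", "yyyy-MM-dd"))

w = Window.partitionBy("Group").orderBy("Date")
demo.withColumn("group_number", F.dense_rank().over(w)).show()
# A's two dates get ranks 1 and 2, but B restarts at 1 -- which is why every
# row in your data (one distinct date per group) ends up with rank 1.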

So a quick fix for your problem is to remove the partitionBy from your windowSpec.
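In code, the fix looks like this (the same window the other answer uses):

windowSpec = Window.orderBy(F.asc("Group"), F.asc("Date"))

df.withColumn("group_number", F.dense_rank().over(windowSpec)).orderBy("Date").show()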

EDIT: If you want to derive the number from the Group column itself, another option is a user-defined function (UDF) passed as the second argument to df.withColumn(). In the UDF you specify your input/output like a normal Python function. Something like this:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def new_column(group):
    # "A" has Unicode code point 65, so "A" -> 1, "B" -> 2, "C" -> 3.
    # Note: this only works when every group is a single letter A-Z.
    return ord(group) - 64

funct = udf(new_column, IntegerType())

df.withColumn("group_number", funct(df["Group"])).orderBy("Date").show()

If you were to use a UDF on the Date instead, you would need some way to keep track of the dates already seen. An example:

date_dict = {}

def new_column(date_obj):
    # Caveat: date_dict is plain Python state; on a real cluster each
    # executor holds its own copy, so this is only reliable in local mode.
    key = date_obj.strftime("%Y-%m-%d")
    if key not in date_dict:
        date_dict[key] = len(date_dict) + 1  # next number for an unseen date
    return date_dict[key]

funct = udf(new_column, IntegerType())

df.withColumn("group_number", funct(df["Date"])).orderBy("Date").show()

Upvotes: 1
