Finn Murphy

Reputation: 3

Updating rows based on the next time a specific value occurs in a PySpark dataframe

If I have a dataframe like this

    data = [("ID1", "ENGAGEMENT", "2019-03-03"), ("ID1", "BABY SHOWER", "2019-04-13"),
            ("ID1", "WEDDING", "2019-07-10"), ("ID1", "DIVORCE", "2019-09-26")]
    df = spark.createDataFrame(data, ["ID", "Event", "start_date"])
    df.show()
    
    +---+-----------+----------+
    | ID|      Event|start_date|
    +---+-----------+----------+
    |ID1| ENGAGEMENT|2019-03-03|
    |ID1|BABY SHOWER|2019-04-13|
    |ID1|    WEDDING|2019-07-10|
    |ID1|    DIVORCE|2019-09-26|
    +---+-----------+----------+

From this dataframe, the end date of each event has to be inferred from the start dates of the subsequent events.

For example: an engagement ends when the wedding happens, so you would take the start date of the wedding as the end date of the engagement.

So the above dataframe should produce this output:

    +---+-----------+----------+----------+
    | ID|      Event|start_date|  end_date|
    +---+-----------+----------+----------+
    |ID1| ENGAGEMENT|2019-03-03|2019-07-10|
    |ID1|BABY SHOWER|2019-04-13|2019-04-13|
    |ID1|    WEDDING|2019-07-10|2019-09-26|
    |ID1|    DIVORCE|2019-09-26|      NULL|
    +---+-----------+----------+----------+
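
To make the rule explicit, this is the mapping I have in mind (a plain Python dict, purely for illustration; DIVORCE has no terminating event, so its end_date stays null):

    # Which later event marks the end of each event (illustrative only)
    END_EVENT = {
        "ENGAGEMENT": "WEDDING",       # an engagement ends when the wedding happens
        "BABY SHOWER": "BABY SHOWER",  # a baby shower ends the same day it starts
        "WEDDING": "DIVORCE",          # a marriage ends at the divorce
    }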

I initially attempted this using the lead function over a window partitioned by the ID to get the rows in front, but since the "WEDDING" event can be 20 rows later, this doesn't work and is a really messy way to do it.

    from pyspark.sql import Window, functions as f

    w = Window.partitionBy("ID").orderBy("start_date")

    df = df.select("*", *[f.lead(f.col(c), default=None).over(w).alias("LEAD_" + c)
                          for c in ["Event", "start_date"]])

    df = df.select("*", *[f.lead(f.col(c), default=None).over(w).alias("LEAD_" + c)
                          for c in ["LEAD_Event", "LEAD_start_date"]])

    df = df.withColumn("end_date",
                       f.when((f.col("Event") == "ENGAGEMENT") & (f.col("LEAD_Event") == "WEDDING"), f.col("LEAD_start_date"))
                        .when((f.col("Event") == "ENGAGEMENT") & (f.col("LEAD_LEAD_Event") == "WEDDING"), f.col("LEAD_LEAD_start_date")))
                       # ... and so on, with another chained .when for every row the wedding might be away

How can I achieve this without looping through the dataset?

Upvotes: 0

Views: 68

Answers (1)

Lamanus

Reputation: 13591

Here is my try.

    from pyspark.sql.functions import expr

    # For each event, pick the first non-null start_date of its terminating
    # event within the ID partition; a BABY SHOWER ends on its own start_date.
    df.withColumn('end_date', expr('''
        case when Event = 'ENGAGEMENT'  then first(if(Event = 'WEDDING', start_date, null), True) over (Partition By ID)
             when Event = 'BABY SHOWER' then first(if(Event = 'BABY SHOWER', start_date, null), True) over (Partition By ID)
             when Event = 'WEDDING'     then first(if(Event = 'DIVORCE', start_date, null), True) over (Partition By ID)
        else null end
    ''')).show()

    +---+-----------+----------+----------+
    | ID|      Event|start_date|  end_date|
    +---+-----------+----------+----------+
    |ID1| ENGAGEMENT|2019-03-03|2019-07-10|
    |ID1|BABY SHOWER|2019-04-13|2019-04-13|
    |ID1|    WEDDING|2019-07-10|2019-09-26|
    |ID1|    DIVORCE|2019-09-26|      null|
    +---+-----------+----------+----------+
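
One caveat: first(...) over (Partition By ID) scans the whole partition, so if an ID ever repeats a cycle (say a second engagement after the divorce), every engagement would pick up the same wedding date. A frame that only looks at the rows after the current one avoids that; here is a rough, untested sketch using the same column names as the question:

    from pyspark.sql import Window, functions as f

    # Frame: rows strictly after the current row, same ID, ordered by start_date.
    w = (Window.partitionBy("ID")
               .orderBy("start_date")
               .rowsBetween(1, Window.unboundedFollowing))

    # First WEDDING / DIVORCE start_date among the following rows.
    next_wedding = f.first(f.when(f.col("Event") == "WEDDING", f.col("start_date")), ignorenulls=True).over(w)
    next_divorce = f.first(f.when(f.col("Event") == "DIVORCE", f.col("start_date")), ignorenulls=True).over(w)

    df.withColumn("end_date",
                  f.when(f.col("Event") == "ENGAGEMENT", next_wedding)
                   .when(f.col("Event") == "BABY SHOWER", f.col("start_date"))
                   .when(f.col("Event") == "WEDDING", next_divorce)).show()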

Upvotes: 1
