user3448011

Reputation: 1599

How to design a customized window function to select column values within a time window of a PySpark dataframe

This question is related to my previous question: pyspark dataframe aggregate a column by sliding time window.

However, I would like to create a new post in order to clarify some key points that were missing from my previous question.

The original dataframe:

client_id    value1    name1    a_date
 dhd         561       ecdu     2019-10-8
 dhd         561       tygp     2019-10-8  
 dhd         561       rdsr     2019-10-8
 dhd         561       rgvd     2019-8-12
 dhd         561       bhnd     2019-8-12
 dhd         561       prti     2019-8-12
 dhd         561       teuq     2019-5-7
 dhd         561       wnva     2019-5-7
 dhd         561       pqhn     2019-5-7

I need to collect the values of "name1" for each "client_id" and each "value1" within a given sliding time window.

I defined a window function:

 from pyspark.sql.window import Window

 w = Window().partitionBy("client_id", "value1").orderBy("a_date")

But I do not know how to select the values of "name1" for window sizes of 1, 2, 6, 9, and 12.

Here, window size means the number of months counted back from the current month of "a_date".

e.g.

 client_id    value1    names1_within_window_size_1    names1_within_window_size_2
 dhd          561       [ecdu, tygp, rdsr]             [ecdu, tygp, rdsr]

 names1_within_window_size_6
 [ecdu, tygp, rdsr, rgvd, bhnd, prti, teuq, wnva, pqhn]

 names1_within_window_size_1 : the month window 2019-10
 names1_within_window_size_2 : the month windows 2019-10 and 2019-09 (no data in 2019-09, so just keep the data from 2019-10)
 names1_within_window_size_6 : the month windows from 2019-10 back to 2019-05; 2019-09 has no data, but 2019-08 and 2019-05 do, so their names are included

Thanks

============================================ UPDATE

from pyspark.sql import functions as F
from pyspark.sql.window import Window

data=  [['dhd',589,'ecdu','2020-1-5'],
    ['dhd',575,'tygp','2020-1-5'],  
    ['dhd',821,'rdsr','2020-1-5'],
    ['dhd',872,'rgvd','2019-12-1'],
    ['dhd',619,'bhnd','2019-12-15'],
    ['dhd',781,'prti','2019-12-18'],
    ['dhd',781,'prti1','2019-12-18'],
    ['dhd',781,'prti2','2019-11-18'],
    ['dhd',781,'prti3','2019-10-31'],
    ['dhd',781,'prti4','2019-09-30'],
    ['dhd',781,'prt1','2019-07-31'],
    ['dhd',781,'pr4','2019-06-30'],
    ['dhd',781,'pr2','2019-08-31'],
    ['dhd',781,'prt4','2019-01-31'],
    ['dhd',781,'prti6','2019-02-28'],
    ['dhd',781,'prti7','2019-02-02'],
    ['dhd',781,'prti8','2019-03-29'],
    ['dhd',781,'prti9','2019-04-29'],
    ['dhd',781,'prti10','2019-05-04'],
    ['dhd',781,'prti11','2019-03-01'],
    ['dhd',781,'prti12','2018-12-17'],
    ['dhd',781,'prti15','2018-11-21'],
    ['dhd',781,'prti17','2018-10-31'],
    ['dhd',781,'prti19','2018-09-5']

   ]
columns= ['client_id','value1','name1','a_date']

df= spark.createDataFrame(data,columns)

df2 = df.withColumn("year_val", F.year("a_date"))\
    .withColumn("month_val", F.month("a_date"))\
    .withColumn("year_month", F.year(F.col("a_date")) * 100 + 
    F.month(F.col("a_date")))\
    .groupBy("client_id", "value1", "year_month")\
    .agg(F.concat_ws(", ", F.collect_list("name1")).alias("init_list"))

df2.sort(F.col("value1").desc(), F.col("year_month").desc()).show()

w = Window().partitionBy("client_id", "value1")\
    .orderBy("year_month")
df4 = df2.withColumn("a_rank", F.dense_rank().over(w))
df4.sort(F.col("value1"), F.col("year_month")).show()


month_range = 3
w = Window().partitionBy("client_id", "value1")\
    .orderBy("a_rank")\
    .rangeBetween(-(month_range),0)

df5 = df4.withColumn("last_" + str(month_range) + "_month", F.collect_list(F.col("init_list")).over(w))\
    .orderBy("value1", "a_rank")

df6 = df5.sort(F.col("value1").desc(), F.col("year_month").desc())
df6.show(100,False)

Upvotes: 0

Views: 48

Answers (1)

xenodevil

Reputation: 604

I borrowed the data from your previous question for this, as I was too lazy to craft it myself and some nice guy had already prepared the input list over there.

Since the window slides over a number of records rather than a number of months, I combined all records for a given month (grouped by client_id and value1, of course) into a single record via .groupBy("client_id", "value1", "year_val", "month_val"), which appears in the computation of df2:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

data=  [['dhd',589,'ecdu','2020-1-5'],
        ['dhd',575,'tygp','2020-1-5'],  
        ['dhd',821,'rdsr','2020-1-5'],
        ['dhd',872,'rgvd','2019-12-1'],
        ['dhd',619,'bhnd','2019-12-15'],
        ['dhd',781,'prti','2019-12-18'],
        ['dhd',781,'prti1','2019-12-18'],
        ['dhd',781,'prti2','2019-11-18'],
        ['dhd',781,'prti3','2019-10-31'],
        ['dhd',781,'prti4','2019-09-30'],
        ['dhd',781,'prt1','2019-07-31'],
        ['dhd',781,'pr4','2019-06-30'],
        ['dhd',781,'pr2','2019-08-31'],
        ['dhd',781,'prt4','2019-01-31'],
        ['dhd',781,'prti6','2019-02-28'],
        ['dhd',781,'prti7','2019-02-02'],
        ['dhd',781,'prti8','2019-03-29'],
        ['dhd',781,'prti9','2019-04-29'],
        ['dhd',781,'prti10','2019-05-04'],
        ['dhd',781,'prti11','2019-03-01']]
columns= ['client_id','value1','name1','a_date']

df= spark.createDataFrame(data,columns)

df2 = df.withColumn("year_val", F.year("a_date"))\
        .withColumn("month_val", F.month("a_date"))\
        .groupBy("client_id", "value1", "year_val", "month_val")\
        .agg(F.concat_ws(", ", F.collect_list("name1")).alias("init_list"))

df2.show()

Here, we get init_list as:

+---------+------+--------+---------+-------------+
|client_id|value1|year_val|month_val|    init_list|
+---------+------+--------+---------+-------------+
|      dhd|   781|    2019|       12|  prti, prti1|
|      dhd|   589|    2020|        1|         ecdu|
|      dhd|   781|    2019|        8|          pr2|
|      dhd|   781|    2019|        3|prti8, prti11|
|      dhd|   575|    2020|        1|         tygp|
|      dhd|   781|    2019|        5|       prti10|
|      dhd|   781|    2019|        9|        prti4|
|      dhd|   781|    2019|       11|        prti2|
|      dhd|   781|    2019|       10|        prti3|
|      dhd|   821|    2020|        1|         rdsr|
|      dhd|   781|    2019|        6|          pr4|
|      dhd|   619|    2019|       12|         bhnd|
|      dhd|   781|    2019|        7|         prt1|
|      dhd|   781|    2019|        4|        prti9|
|      dhd|   781|    2019|        1|         prt4|
|      dhd|   781|    2019|        2| prti6, prti7|
|      dhd|   872|    2019|       12|         rgvd|
+---------+------+--------+---------+-------------+

Using this, we can get the final result by simply running the window over the records:

month_range = 6
w = Window().partitionBy("client_id", "value1")\
        .orderBy("month_val")\
        .rangeBetween(-(month_range+1),0)

df3 = df2.withColumn("last_0_month", F.collect_list(F.col("init_list")).over(w))\
        .orderBy("value1", "year_val", "month_val")

df3.show(100,False)

Which gives us:

+---------+------+--------+---------+-------------+-------------------------------------------------------------------+
|client_id|value1|year_val|month_val|init_list    |last_0_month                                                       |
+---------+------+--------+---------+-------------+-------------------------------------------------------------------+
|dhd      |575   |2020    |1        |tygp         |[tygp]                                                             |
|dhd      |589   |2020    |1        |ecdu         |[ecdu]                                                             |
|dhd      |619   |2019    |12       |bhnd         |[bhnd]                                                             |
|dhd      |781   |2019    |1        |prt4         |[prt4]                                                             |
|dhd      |781   |2019    |2        |prti6, prti7 |[prt4, prti6, prti7]                                               |
|dhd      |781   |2019    |3        |prti8, prti11|[prt4, prti6, prti7, prti8, prti11]                                |
|dhd      |781   |2019    |4        |prti9        |[prt4, prti6, prti7, prti8, prti11, prti9]                         |
|dhd      |781   |2019    |5        |prti10       |[prt4, prti6, prti7, prti8, prti11, prti9, prti10]                 |
|dhd      |781   |2019    |6        |pr4          |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4]            |
|dhd      |781   |2019    |7        |prt1         |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1]      |
|dhd      |781   |2019    |8        |pr2          |[prt4, prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1, pr2] |
|dhd      |781   |2019    |9        |prti4        |[prti6, prti7, prti8, prti11, prti9, prti10, pr4, prt1, pr2, prti4]|
|dhd      |781   |2019    |10       |prti3        |[prti8, prti11, prti9, prti10, pr4, prt1, pr2, prti4, prti3]       |
|dhd      |781   |2019    |11       |prti2        |[prti9, prti10, pr4, prt1, pr2, prti4, prti3, prti2]               |
|dhd      |781   |2019    |12       |prti, prti1  |[prti10, pr4, prt1, pr2, prti4, prti3, prti2, prti, prti1]         |
|dhd      |821   |2020    |1        |rdsr         |[rdsr]                                                             |
|dhd      |872   |2019    |12       |rgvd         |[rgvd]                                                             |
+---------+------+--------+---------+-------------+-------------------------------------------------------------------+

Limitations:

Sadly, after this second step the a_date field is lost, and for sliding window operations with a range defined over them, the orderBy cannot specify multiple columns (note that the orderBy in the window definition uses only month_val). For this reason, this exact solution will not work for data spanning multiple years. However, this can easily be overcome by adding something like a month_id: a single column combining the year and month values, and then using it in the orderBy clause, as sketched below.
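A minimal sketch of that fix, assuming the df2 computed above; df2m and month_id are illustrative names, with month_id counting months monotonically so the range window keeps working across year boundaries:

# Single monotonic month index: consecutive months differ by exactly 1,
# even across a year boundary (e.g. 2019-12 -> 2020-01).
df2m = df2.withColumn("month_id", F.col("year_val") * 12 + F.col("month_val"))

month_range = 6
w = Window().partitionBy("client_id", "value1")\
        .orderBy("month_id")\
        .rangeBetween(-(month_range + 1), 0)

df3 = df2m.withColumn("last_0_month", F.collect_list(F.col("init_list")).over(w))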

If you want to have multiple windows, you can put the month_range values in a list and loop over them in the last code snippet to cover all ranges.
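For example (a sketch using the window sizes from the question; the last_<n>_month column names are illustrative):

# Build one collect_list column per window size.
df_multi = df2
for month_range in [1, 2, 6, 9, 12]:
    w = Window().partitionBy("client_id", "value1")\
            .orderBy("month_val")\
            .rangeBetween(-(month_range + 1), 0)
    df_multi = df_multi.withColumn("last_{}_month".format(month_range),
                                   F.collect_list(F.col("init_list")).over(w))

df_multi.orderBy("value1", "year_val", "month_val").show(100, False)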

Although the last column (last_0_month) looks like an array, it actually contains the comma-separated strings produced by the earlier agg operation. You may want to clean that up as well.
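One possible clean-up (a sketch): join the collected strings back into a single string and re-split it, so each name becomes its own array element:

# Flatten entries like "prti, prti1" into one array of individual names.
df3_clean = df3.withColumn("last_0_month",
                           F.split(F.concat_ws(", ", F.col("last_0_month")), ", "))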

Upvotes: 1
