user12625679

Dynamically calculate a moving average based on the last 3 days of data [PySpark]

I would like to calculate a moving average of purchase_sum for each customer_id and date, based on the values from the previous 3 days. E.g., the moving average for 4 May is the mean of purchase_sum for 1-3 May.

I thought of using a window function, but I am not sure how to calculate the mean over the last 3 days for a given date and customer_id.

Input Spark DataFrame:

date        customer_id   purchase_sum
2020-05-01  1             200
2020-05-02  1             243
2020-05-03  1             232
2020-05-04  1             253
2020-05-05  1             221
2020-05-06  1             212
2020-05-07  1             233

2020-05-01  2             323
2020-05-02  2             342
2020-05-03  2             342
2020-05-04  2             311
2020-05-05  2             344
2020-05-06  2             321
2020-05-07  2             345

Expected output Spark DataFrame (moving averages rounded to the nearest integer):

date        customer_id   purchase_sum  L3D_moving_avg
2020-05-04  1             253           225
2020-05-05  1             221           243
2020-05-06  1             212           235
2020-05-07  1             233           229

2020-05-04  2             311           336
2020-05-05  2             344           332
2020-05-06  2             321           332
2020-05-07  2             345           325
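
The expected values above can be checked without a Spark session using a minimal plain-Python sketch. It assumes one row per customer per day with no missing dates (true for the sample data), so "the last 3 days" are simply the 3 preceding rows; `l3d_moving_avg` is a hypothetical helper name, not part of any library.

```python
from collections import defaultdict

# Input rows copied from the question: (date, customer_id, purchase_sum)
rows = [
    ("2020-05-01", 1, 200), ("2020-05-02", 1, 243), ("2020-05-03", 1, 232),
    ("2020-05-04", 1, 253), ("2020-05-05", 1, 221), ("2020-05-06", 1, 212),
    ("2020-05-07", 1, 233),
    ("2020-05-01", 2, 323), ("2020-05-02", 2, 342), ("2020-05-03", 2, 342),
    ("2020-05-04", 2, 311), ("2020-05-05", 2, 344), ("2020-05-06", 2, 321),
    ("2020-05-07", 2, 345),
]

def l3d_moving_avg(rows):
    """Mean of the 3 preceding purchase_sum values per customer, current row excluded.

    Assumes one row per customer per consecutive day, as in the sample data.
    """
    by_customer = defaultdict(list)
    for date, cid, amount in sorted(rows):  # ISO date strings sort chronologically
        by_customer[cid].append((date, amount))
    result = []
    for cid, history in sorted(by_customer.items()):
        for i in range(3, len(history)):  # the first 3 rows lack a full 3-day history
            date, amount = history[i]
            prev = [a for _, a in history[i - 3:i]]
            result.append((date, cid, amount, sum(prev) / 3))
    return result

for date, cid, amount, avg in l3d_moving_avg(rows):
    print(date, cid, amount, round(avg))
```

Rounded to the nearest integer, this prints 225, 243, 235, 229 for customer 1 and 336, 332, 332, 325 for customer 2.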

Answers (1)

Cena

Use rangeBetween(start, end) to define custom frame boundaries for your window. For a range frame, start and end are offsets relative to the current row's value in the orderBy column (rowsBetween, by contrast, counts physical rows).
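
To make the difference between a range-based frame and a row-based frame concrete, here is a plain-Python sketch of which rows each one selects. `range_frame` and `rows_frame` are hypothetical helper names for illustration, not Spark APIs, and the sketch ignores Spark's special boundary values such as Window.unboundedPreceding.

```python
def range_frame(order_values, current_index, start, end):
    """Indices selected by a RANGE frame: rows whose ordering value v satisfies
    cur + start <= v <= cur + end, where cur is the current row's ordering value."""
    cur = order_values[current_index]
    return [i for i, v in enumerate(order_values) if cur + start <= v <= cur + end]

def rows_frame(n_rows, current_index, start, end):
    """Indices selected by a ROWS frame: physical offsets from the current row
    (rows assumed already sorted by the ordering column), clipped to the partition."""
    lo = max(0, current_index + start)
    hi = min(n_rows - 1, current_index + end)
    return list(range(lo, hi + 1))

# Sorted ordering values with a gap (3 is missing)
vals = [1, 2, 4, 5]
print(range_frame(vals, 3, -3, -1))      # [1, 2]: only values 2 and 4 lie in [5-3, 5-1]
print(rows_frame(len(vals), 3, -3, -1))  # [0, 1, 2]: the 3 physically preceding rows
```

With gap-free, consecutive ordering values the two frames select the same rows; they diverge only when values are missing or repeated.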

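In your case it should be rangeBetween(-3, -1): the frame covers the three preceding values of the ordering column and excludes the current row. Because the window below is ordered by a dense rank of date rather than by the date itself, this means the three preceding dates present in the data; if a customer has gaps in their dates, those are not necessarily the previous three calendar days.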
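
The gap caveat can be illustrated with a small plain-Python sketch that compares "previous 3 recorded dates" (what ordering by a dense rank gives) with "previous 3 calendar days". The history below is hypothetical (the question's customer 1 with 2020-05-03 removed), and both function names are made up for illustration.

```python
from datetime import date

# Hypothetical history for one customer, with 2020-05-03 missing
history = [
    (date(2020, 5, 1), 200),
    (date(2020, 5, 2), 243),
    (date(2020, 5, 4), 253),
    (date(2020, 5, 5), 221),
]

def avg_prev_3_rows(history, i):
    """Dense-rank style: mean over the (up to) 3 preceding recorded dates."""
    prev = [amount for _, amount in history[max(0, i - 3):i]]
    return sum(prev) / len(prev) if prev else None

def avg_prev_3_calendar_days(history, i):
    """Calendar style: mean over records dated 1 to 3 days before the current date."""
    current = history[i][0]
    prev = [amount for d, amount in history if 1 <= (current - d).days <= 3]
    return sum(prev) / len(prev) if prev else None

print(avg_prev_3_rows(history, 3))           # 232.0, using 05-01, 05-02 and 05-04
print(avg_prev_3_calendar_days(history, 3))  # 248.0, using only 05-02 and 05-04
```

If calendar days are what you want, one option (assuming `date` is, or can be cast to, a date) is to order the window by an integer day number such as F.datediff(F.col("date"), F.lit("1970-01-01")) instead of the rank while keeping rangeBetween(-3, -1); verify this against your own data.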

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Number each customer's distinct dates 1, 2, 3, ... in date order
w = Window.partitionBy("customer_id").orderBy("date")
df = df.withColumn("rank", F.dense_rank().over(w))

# Frame: rows whose rank is 1 to 3 below the current rank (current row excluded)
w2 = Window.partitionBy("customer_id").orderBy("rank").rangeBetween(-3, -1)

(df.select("*", F.mean("purchase_sum").over(w2).alias("L3D_moving_avg"))
    .filter(F.col("rank") >= 4)  # keep only rows that have 3 earlier dates
    .drop("rank")
    .show())

+----------+-----------+------------+------------------+                        
|      date|customer_id|purchase_sum|    L3D_moving_avg|
+----------+-----------+------------+------------------+
|2020-05-04|          1|         253|             225.0|
|2020-05-05|          1|         221|242.66666666666666|
|2020-05-06|          1|         212|235.33333333333334|
|2020-05-07|          1|         233|228.66666666666666|
|2020-05-04|          2|         311| 335.6666666666667|
|2020-05-05|          2|         344| 331.6666666666667|
|2020-05-06|          2|         321| 332.3333333333333|
|2020-05-07|          2|         345| 325.3333333333333|
+----------+-----------+------------+------------------+
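
The rank column is what lets rangeBetween(-3, -1) step over dates. The plain-Python sketch below mimics what dense_rank assigns within one partition and which rows the range frame then picks; the duplicated date is hypothetical (the question's data has one row per day) and `dense_rank` here is a stand-in written for illustration, not the Spark function.

```python
def dense_rank(values):
    """Rank distinct values 1, 2, 3, ... in ascending order; equal values share
    a rank and no numbers are skipped (like F.dense_rank() within a partition)."""
    rank_of = {v: i + 1 for i, v in enumerate(sorted(set(values)))}
    return [rank_of[v] for v in values]

# Hypothetical dates for one customer, with 2020-05-02 recorded twice
dates = ["2020-05-01", "2020-05-02", "2020-05-02", "2020-05-03", "2020-05-04"]
ranks = dense_rank(dates)
print(ranks)  # [1, 2, 2, 3, 4]

# Rows in the rangeBetween(-3, -1) frame of the last row (rank 4): ranks 1 to 3
current = len(ranks) - 1
frame = [i for i, r in enumerate(ranks) if ranks[current] - 3 <= r <= ranks[current] - 1]
print(frame)  # [0, 1, 2, 3]
```

So when a customer can have several rows for the same date, the mean is taken over all of those rows rather than over three per-day totals; aggregating to one row per customer_id and date first would avoid that.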
