antimuon

Reputation: 262

Pyspark Rolling Average starting at first row

I am trying to calculate a rolling average in PySpark. I have it working, but the behavior differs from what I expected: the rolling average starts at the first row.

For example:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
       ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]

df_test = spark.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
df_test.withColumn('rolling_average', f.avg('value').over(win)).show()

+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|           20000.0|
|  JAN| 02| 40000|           30000.0|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|           10000.0|
|  FEB| 02| 50000|           30000.0|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+

This is more in line with what I expect. Is there a way to get this behavior?

+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|              null|
|  JAN| 02| 40000|              null|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|              null|
|  FEB| 02| 50000|              null|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+

The issue with the default behavior is that I need another column to keep track of where the lag should start from.
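To make the semantics I'm after concrete, here it is in plain Python (`rolling_avg` is just an illustrative helper, not part of the PySpark job): return null (`None`) until a full window of values is available, then average the trailing window.

```python
def rolling_avg(values, window=3):
    """None until `window` values have been seen, then the trailing mean."""
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)  # not enough rows yet -> null
        else:
            out.append(sum(values[i - window + 1 : i + 1]) / window)
    return out

jan = [20000, 40000, 30000, 25000, 5000, 15000]
print(rolling_avg(jan))
# [None, None, 30000.0, 31666.666666666668, 20000.0, 15000.0]
```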

Upvotes: 1

Views: 405

Answers (2)

Lamanus

Reputation: 13541

A more compact version of @484's answer.

import pyspark.sql.functions as f
from pyspark.sql import Window

w1 = Window.partitionBy('month').orderBy('day')
w2 = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)

df.withColumn(
    "rolling_average",
    f.when(f.row_number().over(w1) > f.lit(2), f.avg('value').over(w2))
).show(10, False)

p.s. Please do not mark this as an answer :)

Upvotes: 1

notNull

Reputation: 31480

Try the row_number() window function, then use a when + otherwise expression to replace the rolling average with null on the first rows.

  • To change where the nulls stop, change the value in the when condition col("rn") <= <value>.

Example:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import row_number, when, col, lit

spark = SparkSession.builder.getOrCreate()

columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
       ('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]

df_test = spark.createDataFrame(data).toDF(*columns)
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2,0)

win1 = Window.partitionBy('month').orderBy('day')

df_test.withColumn('rolling_average', f.avg('value').over(win)).\
withColumn("rn",row_number().over(win1)).\
withColumn("rolling_average",when(col("rn") <= 2 ,lit(None)).\
otherwise(col("rolling_average"))).\
drop("rn").\
show()
#+-----+---+------+------------------+
#|month|day| value|   rolling_average|
#+-----+---+------+------------------+
#|  FEB| 01| 10000|              null|
#|  FEB| 02| 50000|              null|
#|  FEB| 03|100000|53333.333333333336|
#|  FEB| 04| 60000|           70000.0|
#|  FEB| 05|  1000|53666.666666666664|
#|  FEB| 06| 10000|23666.666666666668|
#|  JAN| 01| 20000|              null|
#|  JAN| 02| 40000|              null|
#|  JAN| 03| 30000|           30000.0|
#|  JAN| 04| 25000|31666.666666666668|
#|  JAN| 05|  5000|           20000.0|
#|  JAN| 06| 15000|           15000.0|
#+-----+---+------+------------------+

Upvotes: 2
