Reputation: 262
I am trying to calculate a rolling average in PySpark. I have it working, but it behaves differently than I expected: the rolling average starts at the first row of each month, even though fewer than three values are available there.
For example:
import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window
spark = SparkSession.builder.getOrCreate()
columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]
df_test = spark.createDataFrame(data).toDF(*columns)
# Frame: the current row and up to two preceding rows within each month
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
df_test.withColumn('rolling_average', f.avg('value').over(win)).show()
+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|           20000.0|
|  JAN| 02| 40000|           30000.0|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|           10000.0|
|  FEB| 02| 50000|           30000.0|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
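To see why the first rows are not null, here is a plain-Python model of the window (no Spark needed; the helper name `rolling_avg_partial` is hypothetical). With `rowsBetween(-2, 0)` the frame for each row is the current row plus up to two preceding rows in the same partition, and at the start of a partition the frame is simply truncated, so the first two days of each month are averaged over one and two values respectively.

```python
from collections import defaultdict


def rolling_avg_partial(rows, preceding=2):
    """Model of avg(value) over partitionBy(month).orderBy(day).rowsBetween(-preceding, 0).

    rows: iterable of (month, day, value) tuples.
    Returns a list of (month, day, value, rolling_average) tuples.
    """
    # Group rows by partition key, mirroring partitionBy('month').
    partitions = defaultdict(list)
    for month, day, value in rows:
        partitions[month].append((day, value))

    result = []
    for month, items in partitions.items():
        items.sort(key=lambda item: item[0])  # orderBy('day'); '01' < '02' as strings
        values = [v for _, v in items]
        for i, (day, value) in enumerate(items):
            # The frame is truncated at the partition start, so early rows see fewer values.
            frame = values[max(0, i - preceding): i + 1]
            result.append((month, day, value, sum(frame) / len(frame)))
    return result


data = [('JAN', '01', 20000), ('JAN', '02', 40000), ('JAN', '03', 30000),
        ('JAN', '04', 25000), ('JAN', '05', 5000), ('JAN', '06', 15000)]
for row in rolling_avg_partial(data):
    print(row)
```

The averages match the JAN rows of the Spark output above, including the partial-frame values on days 01 and 02.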
This is the behavior I would expect: the average stays null until a full three-row window is available. Is there a way to get this?
+-----+---+------+------------------+
|month|day| value|   rolling_average|
+-----+---+------+------------------+
|  JAN| 01| 20000|              null|
|  JAN| 02| 40000|              null|
|  JAN| 03| 30000|           30000.0|
|  JAN| 04| 25000|31666.666666666668|
|  JAN| 05|  5000|           20000.0|
|  JAN| 06| 15000|           15000.0|
|  FEB| 01| 10000|              null|
|  FEB| 02| 50000|              null|
|  FEB| 03|100000|53333.333333333336|
|  FEB| 04| 60000|           70000.0|
|  FEB| 05|  1000|53666.666666666664|
|  FEB| 06| 10000|23666.666666666668|
+-----+---+------+------------------+
The problem with the default behavior is that I would need an extra column to keep track of where the lag should start.
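One way to avoid a separate tracking column is to return null whenever the frame holds fewer rows than the window size. Below is a plain-Python model of that rule (the helper `rolling_avg_full_only` and its `size` parameter are hypothetical). In PySpark the analogous check should be expressible on the question's existing window, e.g. `f.when(f.count(f.lit(1)).over(win) == 3, f.avg('value').over(win))`; treat that as an untested suggestion, not something taken from this thread.

```python
from collections import defaultdict


def rolling_avg_full_only(rows, size=3):
    """Average over the current row and the (size - 1) preceding rows in the
    same month; None while the frame is not yet full (fewer than `size` rows)."""
    partitions = defaultdict(list)
    for month, day, value in rows:
        partitions[month].append((day, value))

    result = []
    for month, items in partitions.items():
        items.sort(key=lambda item: item[0])  # order by day within the month
        values = [v for _, v in items]
        for i, (day, value) in enumerate(items):
            frame = values[max(0, i - (size - 1)): i + 1]
            # Gate on the frame size itself instead of a separate counter column.
            avg = sum(frame) / len(frame) if len(frame) == size else None
            result.append((month, day, value, avg))
    return result


data = [('FEB', '01', 10000), ('FEB', '02', 50000), ('FEB', '03', 100000),
        ('FEB', '04', 60000), ('FEB', '05', 1000), ('FEB', '06', 10000)]
for row in rolling_avg_full_only(data):
    print(row)
```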
Upvotes: 1
Reputation: 13541
A shorter version of @484's answer:
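import pyspark.sql.functions as f
from pyspark.sql import Window
# Ordered window used only to number the rows within each month
w1 = Window.partitionBy('month').orderBy('day')
# Rolling frame: the current row and the two preceding rows
w2 = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
# when() without otherwise() returns null for the first two rows of each month
df_test.withColumn("rolling_average", f.when(f.row_number().over(w1) > f.lit(2), f.avg('value').over(w2))).show(truncate=False)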
P.S. Please do not accept this as the answer :)
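Because `f.when` has no `otherwise`, every row where `row_number()` is 2 or less gets null. A plain-Python sketch of that gate for one already-sorted partition (the helper `gate_by_row_number` is hypothetical) shows that, for a frame of two preceding rows, `row_number > 2` selects exactly the rows whose frame is full:

```python
def gate_by_row_number(values, preceding=2):
    """Mimic f.when(row_number > preceding, avg over rowsBetween(-preceding, 0)).

    `values` is one partition, already sorted by day.
    """
    out = []
    for rn, _ in enumerate(values, start=1):  # row_number() starts at 1
        frame = values[max(0, rn - 1 - preceding): rn]
        # when() without otherwise() yields null (None here) if the condition fails.
        out.append(sum(frame) / len(frame) if rn > preceding else None)
    return out


jan = [20000, 40000, 30000, 25000, 5000, 15000]
print(gate_by_row_number(jan))
```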
Upvotes: 1
Reputation: 31480
Try the row_number() window function, then use a when/otherwise statement to replace the first rolling averages with null.
If you want to change where the lag starts, change the <value> in the when statement col("rn") <= <value>.
Example:
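import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, lit, row_number, when
spark = SparkSession.builder.getOrCreate()
columns = ['month', 'day', 'value']
data = [('JAN', '01', '20000'), ('JAN', '02', '40000'), ('JAN', '03', '30000'), ('JAN', '04', '25000'), ('JAN', '05', '5000'), ('JAN', '06', '15000'),
('FEB', '01', '10000'), ('FEB', '02', '50000'), ('FEB', '03', '100000'), ('FEB', '04', '60000'), ('FEB', '05', '1000'), ('FEB', '06', '10000'),]
df_test = spark.createDataFrame(data).toDF(*columns)
# Rolling frame: the current row and the two preceding rows within each month
win = Window.partitionBy('month').orderBy('day').rowsBetween(-2, 0)
# Ordered window used to number the rows within each month
win1 = Window.partitionBy('month').orderBy('day')
# Average every row, number the rows, then null out the first two rows per month
(df_test.withColumn('rolling_average', f.avg('value').over(win))
    .withColumn("rn", row_number().over(win1))
    .withColumn("rolling_average", when(col("rn") <= 2, lit(None)).otherwise(col("rolling_average")))
    .drop("rn")
    .show())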
#+-----+---+------+------------------+
#|month|day| value|   rolling_average|
#+-----+---+------+------------------+
#|  FEB| 01| 10000|              null|
#|  FEB| 02| 50000|              null|
#|  FEB| 03|100000|53333.333333333336|
#|  FEB| 04| 60000|           70000.0|
#|  FEB| 05|  1000|53666.666666666664|
#|  FEB| 06| 10000|23666.666666666668|
#|  JAN| 01| 20000|              null|
#|  JAN| 02| 40000|              null|
#|  JAN| 03| 30000|           30000.0|
#|  JAN| 04| 25000|31666.666666666668|
#|  JAN| 05|  5000|           20000.0|
#|  JAN| 06| 15000|           15000.0|
#+-----+---+------+------------------+
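In this approach the average is computed for every row first and then overwritten with null where rn <= <value>. The plain-Python sketch below (the helper `null_first_rows` and its parameter `k` are hypothetical) treats that threshold as a parameter. With k = 2 it reproduces the JAN rows above; a larger k also blanks rows whose three-row frame is already full, so for a frame of n rows the threshold would normally be n - 1.

```python
def null_first_rows(values, k, preceding=2):
    """Compute the rolling average for every row, then replace it with None
    for the first k rows, like when(col("rn") <= k, lit(None)).otherwise(...)."""
    averages = []
    for i in range(len(values)):
        frame = values[max(0, i - preceding): i + 1]
        averages.append(sum(frame) / len(frame))
    # row_number() is 1-based, so rn <= k covers list indexes 0 .. k-1.
    return [None if i + 1 <= k else avg for i, avg in enumerate(averages)]


jan = [20000, 40000, 30000, 25000, 5000, 15000]
print(null_first_rows(jan, k=2))
print(null_first_rows(jan, k=3))
```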
Upvotes: 2