Reputation: 8854
How do I calculate the rolling median of the dollars column over a window of the previous 3 values?
Input data
dollars timestampGMT
25 2017-03-18 11:27:18
17 2017-03-18 11:27:19
13 2017-03-18 11:27:20
27 2017-03-18 11:27:21
13 2017-03-18 11:27:22
43 2017-03-18 11:27:23
12 2017-03-18 11:27:24
Expected Output data
dollars timestampGMT rolling_median_dollar
25 2017-03-18 11:27:18 median(25)
17 2017-03-18 11:27:19 median(17,25)
13 2017-03-18 11:27:20 median(13,17,25)
27 2017-03-18 11:27:21 median(27,13,17)
13 2017-03-18 11:27:22 median(13,27,13)
43 2017-03-18 11:27:23 median(43,13,27)
12 2017-03-18 11:27:24 median(12,43,13)
The code below does a moving average (following pyspark: rolling average using timeseries data), but PySpark doesn't have an F.median() function.
EDIT 1: The challenge is that a median() function doesn't exist, so I cannot do
df = df.withColumn('rolling_average', F.median("dollars").over(w))
If I wanted a moving average, I could have done
df = df.withColumn('rolling_average', F.avg("dollars").over(w))
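For reference, the window w could be defined roughly like this (a sketch; the rows-based frame of the current row plus the two preceding rows is an assumption chosen to match the expected output above):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# assumed window: the current row and the 2 preceding rows, ordered by time
w = Window.orderBy("timestampGMT").rowsBetween(-2, 0)

df = df.withColumn('rolling_average', F.avg("dollars").over(w))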
EDIT 2: Tried using approxQuantile()
windfun = Window().partitionBy().orderBy(F.col(date_column)).rowsBetween(-3, 0)
sdf.withColumn("movingMedian", sdf.approxQuantile(col='a', probabilities=[0.5], relativeError=0.00001).over(windfun))
But I get this error:
AttributeError: 'list' object has no attribute 'over'
EDIT 3: Please give a solution without a UDF, since UDFs don't benefit from Catalyst optimization.
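The error happens because approxQuantile() is a DataFrame method: it computes the quantile eagerly and returns a plain Python list, not a Column, so there is nothing to call .over() on. A window-compatible, UDF-free sketch would instead use the SQL percentile_approx aggregate through F.expr (column names taken from the sample data above; this mirrors the accepted answer's approach):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rows-based frame: current row plus the two preceding rows
w = Window.orderBy("timestampGMT").rowsBetween(-2, 0)

sdf = sdf.withColumn("movingMedian", F.expr("percentile_approx(dollars, 0.5)").over(w))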
Upvotes: 11
Views: 11196
Reputation: 71
Another way, without using any UDF, is to use expr from pyspark.sql.functions:
from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

data = [{'dollars': 25, 'timestampGMT': '2017-03-18 11:27:18'},
        {'dollars': 17, 'timestampGMT': '2017-03-18 11:27:19'},
        {'dollars': 13, 'timestampGMT': '2017-03-18 11:27:20'},
        {'dollars': 27, 'timestampGMT': '2017-03-18 11:27:21'},
        {'dollars': 13, 'timestampGMT': '2017-03-18 11:27:22'},
        {'dollars': 43, 'timestampGMT': '2017-03-18 11:27:23'},
        {'dollars': 12, 'timestampGMT': '2017-03-18 11:27:24'}]

test = spark.createDataFrame(data, schema=['dollars', 'timestampGMT'])
test.withColumn("id", F.lit(1)).withColumn(
"rolling_median_dollar",
F.expr("percentile(dollars,0.5)").over(
W.partitionBy("id")
.orderBy(F.col("timestampGMT").cast("long"))
.rowsBetween(-2, 0)
),
).drop('id').show()
+-------+-------------------+---------------------+
|dollars| timestampGMT|rolling_median_dollar|
+-------+-------------------+---------------------+
| 25|2017-03-18 11:27:18| 25.0|
| 17|2017-03-18 11:27:19| 21.0|
| 13|2017-03-18 11:27:20| 17.0|
| 27|2017-03-18 11:27:21| 17.0|
| 13|2017-03-18 11:27:22| 13.0|
| 43|2017-03-18 11:27:23| 27.0|
| 12|2017-03-18 11:27:24| 13.0|
+-------+-------------------+---------------------+
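On newer Spark versions the same thing can be written with native DataFrame functions instead of going through expr; a sketch assuming Spark 3.1+ for percentile_approx (and Spark 3.4+ for the exact median):

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

w = (W.partitionBy("id")
      .orderBy(F.col("timestampGMT").cast("timestamp").cast("long"))
      .rowsBetween(-2, 0))

# Spark 3.1+: approximate percentile as a native function
test.withColumn("id", F.lit(1)) \
    .withColumn("rolling_median_dollar", F.percentile_approx("dollars", 0.5).over(w)) \
    .drop("id").show()

# Spark 3.4+: exact median as a native function
# .withColumn("rolling_median_dollar", F.median("dollars").over(w))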
Upvotes: 3
Reputation: 24198
One way is to collect the dollars column as a list per window, and then calculate the median of the resulting lists using a udf:
from pyspark.sql.window import Window
from pyspark.sql.functions import *
import numpy as np
from pyspark.sql.types import FloatType

# range-based frame over epoch seconds: the current row plus the previous 2 seconds
# (equivalent to the previous 3 rows here because the rows are 1 second apart)
w = Window.orderBy(col("timestampGMT").cast('long')).rangeBetween(-2, 0)

# udf that computes the median of a collected list of values
median_udf = udf(lambda x: float(np.median(x)), FloatType())

df.withColumn("list", collect_list("dollars").over(w)) \
  .withColumn("rolling_median", median_udf("list")).show(truncate=False)
+-------+---------------------+------------+--------------+
|dollars|timestampGMT |list |rolling_median|
+-------+---------------------+------------+--------------+
|25 |2017-03-18 11:27:18.0|[25] |25.0 |
|17 |2017-03-18 11:27:19.0|[25, 17] |21.0 |
|13 |2017-03-18 11:27:20.0|[25, 17, 13]|17.0 |
|27 |2017-03-18 11:27:21.0|[17, 13, 27]|17.0 |
|13 |2017-03-18 11:27:22.0|[13, 27, 13]|13.0 |
|43 |2017-03-18 11:27:23.0|[27, 13, 43]|27.0 |
|12 |2017-03-18 11:27:24.0|[13, 43, 12]|13.0 |
+-------+---------------------+------------+--------------+
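Note that the window above is range-based over epoch seconds (rangeBetween(-2, 0)), which matches the expected output only because the sample rows are exactly one second apart. If the requirement is strictly "the previous 3 rows" regardless of spacing, a rows-based frame can be used instead; a minimal sketch under that assumption:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, collect_list, udf
from pyspark.sql.types import FloatType
import numpy as np

# rows-based frame: the current row plus the two preceding rows, regardless of timestamp gaps
w_rows = Window.orderBy(col("timestampGMT")).rowsBetween(-2, 0)

median_udf = udf(lambda xs: float(np.median(xs)), FloatType())

df.withColumn("rolling_median", median_udf(collect_list("dollars").over(w_rows))) \
  .show(truncate=False)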
Upvotes: 12