Reputation: 63
I have a spark dataframe that looks something like below.
date | ID | window_size | qty |
---|---|---|---|
01/01/2020 | 1 | 2 | 1 |
02/01/2020 | 1 | 2 | 2 |
03/01/2020 | 1 | 2 | 3 |
04/01/2020 | 1 | 2 | 4 |
01/01/2020 | 2 | 3 | 1 |
02/01/2020 | 2 | 3 | 2 |
03/01/2020 | 2 | 3 | 3 |
04/01/2020 | 2 | 3 | 4 |
I'm trying to apply a rolling window of size window_size to each ID in the dataframe and get the rolling sum. Basically I'm calculating a rolling sum (pd.groupby.rolling(window=n).sum() in pandas) where the window size n can change per group.
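For illustration, this is roughly the computation I mean in pandas on a small dataframe. This is only a sketch: it assumes the rows are already sorted by date within each ID and that window_size is constant within a group.

import pandas as pd

pdf = pd.DataFrame({
    "date": ["01/01/2020", "02/01/2020", "03/01/2020", "04/01/2020"] * 2,
    "ID": [1] * 4 + [2] * 4,
    "window_size": [2] * 4 + [3] * 4,
    "qty": [1, 2, 3, 4, 1, 2, 3, 4],
})

# each ID uses its own window size for the rolling sum
pdf["rolling_sum"] = (
    pdf.groupby("ID", group_keys=False)
       .apply(lambda g: g["qty"].rolling(g["window_size"].iloc[0]).sum())
)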
Expected output
date | ID | window_size | qty | rolling_sum |
---|---|---|---|---|
01/01/2020 | 1 | 2 | 1 | null |
02/01/2020 | 1 | 2 | 2 | 3 |
03/01/2020 | 1 | 2 | 3 | 5 |
04/01/2020 | 1 | 2 | 4 | 7 |
01/01/2020 | 2 | 3 | 1 | null |
02/01/2020 | 2 | 3 | 2 | null |
03/01/2020 | 2 | 3 | 3 | 6 |
04/01/2020 | 2 | 3 | 4 | 9 |
I'm struggling to find a solution that works and is fast enough on a large dataframe (+- 350M rows).
What I have tried
I tried the solution in the below thread:
The idea is to first use sf.collect_list and then slice the resulting ArrayType column correctly.
from pyspark.sql import Window
import pyspark.sql.types as st
import pyspark.sql.functions as sf

window = Window.partitionBy('id').orderBy('date')
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    # passing Columns as the start/length arguments of sf.slice is where it fails
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                .otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()
However, this yields the following error:
TypeError: Column is not iterable
I have also tried using sf.expr like below:
window = Window.partitionBy('id').orderBy('date')
output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                .otherwise(sf.expr("slice('qty_list', 'count', 'window_size')")))
).show()
Which yields:
data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;
I tried manually casting the qty_list column to ArrayType(IntegerType()) with the same result.
I also tried using a UDF, but that fails with several out-of-memory errors after 1.5 hours or so.
Questions
Reading the Spark documentation suggests to me that I should be able to pass columns to sf.slice(). Am I doing something wrong? Where is the TypeError coming from?
Is there a better way to achieve what I want without using sf.collect_list() and/or sf.slice()?
If all else fails, what would be the optimal way to do this using a UDF? I attempted different versions of the same UDF and tried to make sure the UDF is the last operation Spark has to perform, but all failed.
Upvotes: 2
Views: 1235
Reputation: 32680
About the errors you get:
slice does not accept Column arguments for its start and length parameters when used as a DataFrame API function (unless you have Spark 3.1+). But you already got around that as you tried using it within a SQL expression with expr. There it should be slice(qty_list, count, window_size), otherwise Spark considers the quoted names as plain strings, hence the error message.
That said, you almost got it. You need to change the expression for slicing to get the correct size of array, then use the aggregate function to sum up the values of the resulting array. Try with this:
from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy('id').orderBy('date')

output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
    .withColumn("rn", F.row_number().over(w)) \
    .withColumn(
        "qty_list",
        F.when(
            F.col('rn') < F.col('window_size'),  # not enough rows yet for a full window
            None
        ).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
    ).withColumn(
        "rolling_sum",
        F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
    ).drop("qty_list", "rn")

output.show()
#+----------+---+-----------+---+-----------+
#| date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020| 1| 2| 1| null|
#|02/01/2020| 1| 2| 2| 3|
#|03/01/2020| 1| 2| 3| 5|
#|04/01/2020| 1| 2| 4| 7|
#|01/01/2020| 2| 3| 1| null|
#|02/01/2020| 2| 3| 2| null|
#|03/01/2020| 2| 3| 3| 6|
#|04/01/2020| 2| 3| 4| 9|
#+----------+---+-----------+---+-----------+
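If you are on Spark 3.1+, you can skip expr entirely, since F.slice and F.aggregate accept Column arguments there. This is only a sketch of that variant, assuming the same dataframe and column names as above:

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy('id').orderBy('date')

output = (
    df.withColumn("qty_list", F.collect_list('qty').over(w))
      .withColumn("rn", F.row_number().over(w))
      .withColumn(
          "rolling_sum",
          F.when(F.col('rn') < F.col('window_size'), None)
           .otherwise(
               # Spark 3.1+: slice and aggregate take Column arguments directly
               F.aggregate(
                   F.slice('qty_list', F.col('rn') - F.col('window_size') + 1, F.col('window_size')),
                   F.lit(0.0),
                   lambda acc, x: acc + x,
               ).cast("int")
           )
      )
      .drop("qty_list", "rn")
)
output.show()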
Upvotes: 2