Joep Atol

Reputation: 63

PySpark: applying varying window sizes to a dataframe

I have a spark dataframe that looks something like below.

date        ID  window_size  qty
01/01/2020   1            2    1
02/01/2020   1            2    2
03/01/2020   1            2    3
04/01/2020   1            2    4
01/01/2020   2            3    1
02/01/2020   2            3    2
03/01/2020   2            3    3
04/01/2020   2            3    4

I'm trying to apply a rolling window of size window_size to each ID in the dataframe and get the rolling sum. Basically I'm calculating a rolling sum (pd.groupby.rolling(window=n).sum() in pandas) where the window size (n) can change per group.
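
For reference, here's a minimal pandas sketch of what I mean (assuming, as in the sample above, that window_size is constant within each ID):

import pandas as pd

pdf = pd.DataFrame({
    "date": ["01/01/2020", "02/01/2020", "03/01/2020", "04/01/2020"] * 2,
    "ID": [1] * 4 + [2] * 4,
    "window_size": [2] * 4 + [3] * 4,
    "qty": [1, 2, 3, 4] * 2,
})

# window_size is constant per ID, so each group can use its own fixed window
pdf["rolling_sum"] = pdf.groupby("ID", group_keys=False).apply(
    lambda g: g["qty"].rolling(window=g["window_size"].iloc[0]).sum()
)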

Expected output

date        ID  window_size  qty  rolling_sum
01/01/2020   1            2    1         null
02/01/2020   1            2    2            3
03/01/2020   1            2    3            5
04/01/2020   1            2    4            7
01/01/2020   2            3    1         null
02/01/2020   2            3    2         null
03/01/2020   2            3    3            6
04/01/2020   2            3    4            9

I'm struggling to find a solution that works and is fast enough on a large dataframe (~350M rows).

What I have tried

I tried the solution in the below thread:

The idea is to first use sf.collect_list and then slice the ArrayType column correctly.

from pyspark.sql import Window
import pyspark.sql.types as st
import pyspark.sql.functions as sf

window = Window.partitionBy('id').orderBy(params['date'])

output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                                 .otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()

However this yields below error:

TypeError: Column is not iterable

I have also tried using sf.expr like below

window = Window.partitionBy('id').orderBy(params['date'])

output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                                 .otherwise(sf.expr("slice('qty_list', 'count', 'window_size')")))
).show()

Which yields:

data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;

I tried manually casting the qty_list column to ArrayType(IntegerType()) with the same result.

I also tried using a UDF, but that failed with several out-of-memory errors after 1.5 hours or so.
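
For reference, the attempts all had roughly the shape below (a simplified sketch, not the exact code):

import pyspark.sql.types as st
import pyspark.sql.functions as sf

# simplified sketch: slice and sum the collected list in plain Python
@sf.udf(returnType=st.IntegerType())
def rolling_sum_udf(qty_list, count, window_size):
    if count < window_size:
        return None
    return sum(qty_list[count - window_size:count])

output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", rolling_sum_udf('qty_list', 'count', 'window_size'))
)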

Questions

  1. Reading the Spark documentation suggests to me that I should be able to pass columns to sf.slice(). Am I doing something wrong? Where is the TypeError coming from?

  2. Is there a better way to achieve what I want without using sf.collect_list() and/or sf.slice()?

  3. If all else fails, what would be the optimal way to do this using a UDF? I attempted different versions of the same UDF and tried to make sure the UDF is the last operation Spark has to perform, but all of them failed.

Upvotes: 2

Views: 1235

Answers (1)

blackbishop

Reputation: 32680

About the errors you get:

  1. The first one means you can't pass columns to slice through the DataFrame API function (unless you're on Spark 3.1+; see the sketch right after this list). You already found the workaround, though, by calling it inside a SQL expression.
  2. The second error occurs because you quote the column names inside your expr. It should be slice(qty_list, count, window_size); otherwise Spark treats the quoted names as string literals, hence the error message.
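
For reference, on Spark 3.1+ the column-based call from your first attempt becomes valid. A minimal sketch (reusing the qty_list and count columns from your snippet; qty_window is just an illustrative name):

import pyspark.sql.functions as F

# Spark 3.1+ only: slice() accepts Column arguments for start and length
df.withColumn(
    "qty_window",
    F.slice('qty_list', F.col('count') - F.col('window_size') + 1, F.col('window_size'))
)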

That said, you almost got it. You need to change the slicing expression so it takes the correct part of the array, then use the aggregate function to sum up the values of the resulting slice. Try this:

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy('id').orderBy('date')

output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
    .withColumn("rn", F.row_number().over(w)) \
    .withColumn(
        "qty_list",
        F.when(
            # not enough rows in the partition yet -> null
            F.col('rn') < F.col('window_size'),
            None
        # keep only the last window_size elements of the running list
        ).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
    ).withColumn(
        # sum the sliced array; 0D makes the accumulator a double, cast back to int
        "rolling_sum",
        F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
    ).drop("qty_list", "rn")

output.show()
#+----------+---+-----------+---+-----------+
#|      date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020|  1|          2|  1|       null|
#|02/01/2020|  1|          2|  2|          3|
#|03/01/2020|  1|          2|  3|          5|
#|04/01/2020|  1|          2|  4|          7|
#|01/01/2020|  2|          3|  1|       null|
#|02/01/2020|  2|          3|  2|       null|
#|03/01/2020|  2|          3|  3|          6|
#|04/01/2020|  2|          3|  4|          9|
#+----------+---+-----------+---+-----------+
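
One more note: the window above orders by the raw date string, which happens to sort correctly for your sample but won't once the data spans several months (for example, "01/02/2020" sorts before "02/01/2020"). In that case, parse the column to a real date first:

df = df.withColumn("date", F.to_date("date", "dd/MM/yyyy"))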

Upvotes: 2
