gnrtnltlt

Reputation: 21

Cumulative Sum with Reset BEFORE Negative in PySpark

Can someone help me calculate the cumulative sum of these values, where the running sum resets to 0 before it would become negative? I'm trying to implement this in PySpark.

df =

+-----+
|value|
+-----+
|   -1|
|    3|
|   -2|
|    2|
|   -1|
+-----+

Expected result:

+-----+-------+
|value|cum_sum|
+-----+-------+
|   -1|      0|
|    3|      3|
|   -2|      1|
|    2|      3|
|   -1|      2|
+-----+-------+
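For reference, the reset logic can be written as a plain Python loop (a minimal sketch of the behaviour described above, not the PySpark solution itself):

def cum_sum_with_reset(values):
    """Running sum clamped to 0 whenever it would turn negative."""
    acc = 0
    result = []
    for v in values:
        acc = max(acc + v, 0)  # reset to 0 before the sum goes negative
        result.append(acc)
    return result

print(cum_sum_with_reset([-1, 3, -2, 2, -1]))  # [0, 3, 1, 3, 2]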

Upvotes: 0

Views: 244

Answers (1)

Nithish

Reputation: 3232

For Spark >= 2.4.0, you can first apply collect_list to collect the preceding values as a list, and then use aggregate to find the cumulative sum.

I am defining an ordering column for use with the Window function; it's advised to have one, since Spark does not guarantee row ordering otherwise.

I don't define a partition in my code; when working with a large number of rows, partition the dataset appropriately (see the partitioned sketch after the output below).

from pyspark.sql import functions as F
from pyspark.sql import Window as W

data = [(1, -1), (2, 3), (3, -2), (4, 2), (5, -1)]

df = spark.createDataFrame(data, ("ordering_column", "value"))

# Window covering all rows from the start up to the current row, in order
ws = W.orderBy("ordering_column").rowsBetween(W.unboundedPreceding, W.currentRow)

# Fold the collected values, clamping the running total to 0 whenever it would go negative
cum_sum_expr = F.expr("aggregate(cum_sum_values, 0L, (acc, x) -> greatest(acc + x, 0L))")

(df.withColumn("cum_sum_values", F.collect_list("value").over(ws))  # preceding values as a list
   .withColumn("cum_sum", cum_sum_expr)
   .drop("cum_sum_values")
   .show())

"""
+---------------+-----+-------+
|ordering_column|value|cum_sum|
+---------------+-----+-------+
|              1|   -1|      0|
|              2|    3|      3|
|              3|   -2|      1|
|              4|    2|      3|
|              5|   -1|      2|
+---------------+-----+-------+
"""

Upvotes: 2
