Marco

Reputation: 1235

Pyspark Rolling Sum based on ID, timestamp and condition

I have the following PySpark dataframe:

id  timestamp   col1
1   2022-01-01   0
1   2022-01-02   1
1   2022-01-03   1
1   2022-01-04   0
2   2022-01-01   1
2   2022-01-02   0
2   2022-01-03   1

I would like to add a column with the cumulative sum of col1 for each id, ordered by timestamp, to obtain something like this:

id  timestamp   col1  cum_sum
1   2022-01-01   0      0
1   2022-01-02   1      1
1   2022-01-03   1      2
1   2022-01-04   0      2
2   2022-01-01   1      1
2   2022-01-02   0      1
2   2022-01-03   1      2

A window function can probably do this, but I am not sure how to count only the rows where col1 is equal to 1.

Upvotes: 1

Views: 489

Answers (1)

ScootCork

Reputation: 3676

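You do need a window function with a sum. The orderBy on the window is what makes the sum cumulative ('rolling'): when a window has an ordering and no explicit frame, Spark sums from the start of each partition up to the current row. Because col1 only holds 0 and 1, a plain sum already counts just the rows where col1 is 1.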

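import pyspark.sql.functions as F
from pyspark.sql import Window

# One window per id, ordered by timestamp; the ordering makes the sum cumulative
w = Window.partitionBy('id').orderBy('timestamp')
# df is the dataframe from the question
df.withColumn('cum_sum', F.sum('col1').over(w)).show()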

+---+----------+----+-------+
| id| timestamp|col1|cum_sum|
+---+----------+----+-------+
|  1|2022-01-01|   0|      0|
|  1|2022-01-02|   1|      1|
|  1|2022-01-03|   1|      2|
|  1|2022-01-04|   0|      2|
|  2|2022-01-01|   1|      1|
|  2|2022-01-02|   0|      1|
|  2|2022-01-03|   1|      2|
+---+----------+----+-------+

Upvotes: 1
