Reputation: 1235
I have the following PySpark DataFrame:
id   timestamp    col1
1    2022-01-01   0
1    2022-01-02   1
1    2022-01-03   1
1    2022-01-04   0
2    2022-01-01   1
2    2022-01-02   0
2    2022-01-03   1
I would like to add a column with the cumulative sum of col1 for each id, ordered by timestamp, and obtain something like this:
id   timestamp    col1   cum_sum
1    2022-01-01   0      0
1    2022-01-02   1      1
1    2022-01-03   1      2
1    2022-01-04   0      2
2    2022-01-01   1      1
2    2022-01-02   0      1
2    2022-01-03   1      2
A window function can probably work here, but I am not sure how to count only when col1 is equal to 1.
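In case it is useful, here is a minimal sketch to reproduce this DataFrame (the timestamp column is kept as a string here; in practice it may be a date or timestamp type):
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# sample data matching the table above
df = spark.createDataFrame(
    [
        (1, '2022-01-01', 0),
        (1, '2022-01-02', 1),
        (1, '2022-01-03', 1),
        (1, '2022-01-04', 0),
        (2, '2022-01-01', 1),
        (2, '2022-01-02', 0),
        (2, '2022-01-03', 1),
    ],
    ['id', 'timestamp', 'col1'],
)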
Upvotes: 1
Views: 489
Reputation: 3676
You indeed need a window function and a sum; the orderBy on the window is what makes it cumulative ('rolling').
import pyspark.sql.functions as F
from pyspark.sql import Window
# partition by id so each id gets its own running total, ordered by timestamp
w = Window.partitionBy('id').orderBy('timestamp')
df.withColumn('cum_sum', F.sum('col1').over(w)).show()
+---+----------+----+-------+
| id| timestamp|col1|cum_sum|
+---+----------+----+-------+
| 1|2022-01-01| 0| 0|
| 1|2022-01-02| 1| 1|
| 1|2022-01-03| 1| 2|
| 1|2022-01-04| 0| 2|
| 2|2022-01-01| 1| 1|
| 2|2022-01-02| 0| 1|
| 2|2022-01-03| 1| 2|
+---+----------+----+-------+
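As a side note on "count only when col1 is equal to 1": because col1 holds only 0s and 1s here, summing it already gives that count. If col1 could contain other values, a conditional sum over the same window would do it; a minimal sketch under that assumption:
import pyspark.sql.functions as F
from pyspark.sql import Window
w = Window.partitionBy('id').orderBy('timestamp')
# count only the rows where col1 == 1, ignoring any other values
df.withColumn(
    'cum_sum',
    F.sum(F.when(F.col('col1') == 1, 1).otherwise(0)).over(w)
).show()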
Upvotes: 1