Reputation: 1475
I have a dataset in a pyspark job that looks a bit like this:
frame_id    direction_change
1           False
2           False
3           False
4           True
5           False
I want to add a "track" counter to each row so that all the frames between direction changes have the same value. For example, the output I want looks like this:
frame_id    direction_change    track
1           False               1
2           False               1
3           False               1
4           True                2
5           False               2
I have been able to do this in Pandas with the following operation:
frames['track'] = frames['direction_change'].cumsum() + 1
But I can't find an equivalent way to do it with Spark DataFrames. Any help would be really appreciated.
Upvotes: 0
Views: 901
Reputation: 330173
Long story short, there is no efficient way to do this in PySpark with DataFrames alone. One could be tempted to use window functions like this:
from pyspark.sql.functions import col, sum as sum_
from pyspark.sql.window import Window
w = Window().orderBy("frame_id")
df.withColumn("change", 1 + sum_(col("direction_change").cast("long")).over(w))
This is inefficient, however, and won't scale: without a PARTITION BY clause, the window operation moves all of the data to a single partition. It is possible to use lower-level APIs, as shown in How to compute cumulative sum using Spark, but in Python that requires moving out of the Dataset / DataFrame API and using plain RDDs.
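For completeness, here is a minimal sketch of the RDD-level cumulative sum that answer describes: collect per-partition totals of the change flag on the driver, then add a per-partition running sum. It reuses the df from above; the helper names and overall structure are illustrative assumptions, not the linked answer's exact code.

# Sort by frame_id and drop to the RDD level: (frame_id, 0/1 change flag) pairs.
rdd = df.orderBy("frame_id").rdd.map(
    lambda row: (row.frame_id, int(row.direction_change))
)

# Total number of direction changes in each partition, collected to the driver.
partition_totals = rdd.mapPartitionsWithIndex(
    lambda idx, it: [(idx, sum(flag for _, flag in it))]
).collectAsMap()

# Offset for each partition = changes seen in all earlier partitions.
offsets = {}
running = 0
for idx in sorted(partition_totals):
    offsets[idx] = running
    running += partition_totals[idx]

def add_track(idx, it):
    acc = offsets[idx]
    for frame_id, flag in it:
        acc += flag
        yield frame_id, acc + 1  # +1 so the first track is numbered 1

tracks = rdd.mapPartitionsWithIndex(add_track)
tracks.toDF(["frame_id", "track"]).show()

Join the result back on frame_id if the other columns are needed. Only the small per-partition totals ever reach the driver, so this avoids the single-partition shuffle that the window version forces.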
Upvotes: 2