Rocco

Reputation: 1475

PySpark DataFrame - give an ID to a sequence of same values

I have a dataset in a pyspark job that looks a bit like this:

frame_id    direction_change  
1           False  
2           False  
3           False  
4           True  
5           False  

I want to add a "track" counter to each row so that all the frames between direction changes have the same value. For example, the output I want looks like this:

frame_id    direction_change    track
1           False               1
2           False               1
3           False               1
4           True                2
5           False               2  

I have been able to do this with Pandas using the following line:

frames['track'] = frames['direction_change'].astype(int).cumsum() + 1

but I can't find an equivalent way to do it with Spark DataFrames. Any help would be really appreciated.
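The rule being asked for is: the track number is 1 plus the running count of direction changes up to and including the current frame. A plain-Python sketch of that rule, using the sample data from the table above, shows why a bare cumulative sum starts at 0 and needs a +1 offset to match the expected output:

```python
from itertools import accumulate

# Sample data from the question: direction_change flag per frame, ordered by frame_id
direction_change = [False, False, False, True, False]

# Running count of True values (each True counts as 1),
# shifted by 1 so that the first track is numbered 1 rather than 0
track = [1 + n for n in accumulate(int(flag) for flag in direction_change)]

print(track)  # [1, 1, 1, 2, 2]
```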

Upvotes: 0

Views: 901

Answers (1)

zero323

Reputation: 330173

Long story short, there is no efficient way to do this in PySpark with DataFrames alone. One could be tempted to use window functions like this:

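from pyspark.sql.functions import col, sum as sum_
from pyspark.sql.window import Window

# Window ordered by frame_id with no partitionBy: every row ends up in one partition
w = Window.orderBy("frame_id")

# Running count of direction changes, shifted so the first track is 1
df.withColumn("track", 1 + sum_(col("direction_change").cast("long")).over(w))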

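but this is inefficient and won't scale: because the window has no partitionBy clause, Spark has to move all the data into a single partition to compute it (and logs a warning saying so). It is possible to use lower-level APIs, as shown in How to compute cumulative sum using Spark, but in Python that requires moving out of the Dataset / DataFrame API and using plain RDDs.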
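One way a lower-level approach avoids the single-partition bottleneck is to work in two passes: first count the direction changes in each partition independently, then rescan each partition with an offset equal to the total from all earlier partitions. The sketch below simulates this in plain Python, with partitions modeled as lists of (frame_id, direction_change) tuples; the partition layout is an assumption for illustration. In PySpark the per-partition passes would run through RDD methods such as mapPartitionsWithIndex, with only the small list of per-partition totals collected to the driver. Treat it as an outline of the idea, not a drop-in Spark job.

```python
from itertools import accumulate

# Hypothetical partitioning: rows already sorted by frame_id and split
# into contiguous chunks, as a sorted RDD would be
partitions = [
    [(1, False), (2, False)],
    [(3, False), (4, True)],
    [(5, False)],
]

# Pass 1: count direction changes in each partition independently
partition_totals = [sum(int(flag) for _, flag in part) for part in partitions]

# Offsets: number of changes in all *earlier* partitions (exclusive prefix sum)
offsets = [0] + list(accumulate(partition_totals))[:-1]


def assign_tracks(part, offset):
    """Pass 2: local running sum inside one partition, shifted by its offset."""
    running = offset
    for frame_id, flag in part:
        running += int(flag)
        yield frame_id, flag, 1 + running


result = [
    row
    for part, offset in zip(partitions, offsets)
    for row in assign_tracks(part, offset)
]

for row in result:
    print(row)
```

Only the per-partition totals have to be combined centrally; the expensive row-level work stays distributed, which is what lets this scale where the unpartitioned window does not.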

Upvotes: 2
