Reputation: 21047
I have a Spark (2.4.0) data frame with a column that has just two values (either 0 or 1). I need to calculate the streak of consecutive 0s and 1s in this data, resetting the streak to zero if the value changes.
An example:
from pyspark.sql import (SparkSession, Window)
from pyspark.sql.functions import (to_date, row_number, lead, col)
spark = SparkSession.builder.appName('test').getOrCreate()
# Create dataframe
df = spark.createDataFrame([
    ('2018-01-01', 'John', 0, 0),
    ('2018-01-01', 'Paul', 1, 0),
    ('2018-01-08', 'Paul', 3, 1),
    ('2018-01-08', 'Pete', 4, 0),
    ('2018-01-08', 'John', 3, 0),
    ('2018-01-15', 'Mary', 6, 0),
    ('2018-01-15', 'Pete', 6, 0),
    ('2018-01-15', 'John', 6, 1),
    ('2018-01-15', 'Paul', 6, 1),
], ['str_date', 'name', 'value', 'flag'])
df.orderBy('name', 'str_date').show()
## +----------+----+-----+----+
## | str_date|name|value|flag|
## +----------+----+-----+----+
## |2018-01-01|John| 0| 0|
## |2018-01-08|John| 3| 0|
## |2018-01-15|John| 6| 1|
## |2018-01-15|Mary| 6| 0|
## |2018-01-01|Paul| 1| 0|
## |2018-01-08|Paul| 3| 1|
## |2018-01-15|Paul| 6| 1|
## |2018-01-08|Pete| 4| 0|
## |2018-01-15|Pete| 6| 0|
## +----------+----+-----+----+
With this data, I'd like to calculate the streak of consecutive zeros and ones, ordered by date and "windowed" by name:
# Expected result:
## +----------+----+-----+----+--------+--------+
## | str_date|name|value|flag|streak_0|streak_1|
## +----------+----+-----+----+--------+--------+
## |2018-01-01|John| 0| 0| 1| 0|
## |2018-01-08|John| 3| 0| 2| 0|
## |2018-01-15|John| 6| 1| 0| 1|
## |2018-01-15|Mary| 6| 0| 1| 0|
## |2018-01-01|Paul| 1| 0| 1| 0|
## |2018-01-08|Paul| 3| 1| 0| 1|
## |2018-01-15|Paul| 6| 1| 0| 2|
## |2018-01-08|Pete| 4| 0| 1| 0|
## |2018-01-15|Pete| 6| 0| 2| 0|
## +----------+----+-----+----+--------+--------+
Of course, I would need the streak to reset itself to zero if the 'flag' changes.
Is there a way of doing this?
Upvotes: 4
Views: 4441
Reputation: 3417
There is a more intuitive solution that does not use row_number() if you already have a natural ordering column (str_date in this case).
In short, to find the streak of 1's, take the cumulative sum of flag and multiply it by flag. To find the streak of 0's, invert the flag first and then do the same as for the streak of 1's.
First we define a function to calculate the cumulative sum:
from pyspark.sql import Window
from pyspark.sql import functions as f
def cum_sum(df, new_col_name, partition_cols, order_col, value_col):
    windowval = (Window.partitionBy(partition_cols).orderBy(order_col)
                 .rowsBetween(Window.unboundedPreceding, 0))
    return df.withColumn(new_col_name, f.sum(value_col).over(windowval))
Note the use of rowsBetween (instead of rangeBetween). This is important to get the correct cumulative sum when there are duplicate values in the order column.
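To see the difference, here is a minimal sketch (the demo dataframe and its duplicate ordering value are made up purely for illustration; spark, Window and f are the objects already created above):
demo = spark.createDataFrame(
    [(1, 1), (2, 1), (2, 1), (3, 1)], ['ord', 'val'])
w_rows = Window.orderBy('ord').rowsBetween(Window.unboundedPreceding, 0)
w_range = Window.orderBy('ord').rangeBetween(Window.unboundedPreceding, 0)
demo.withColumn('sum_rows', f.sum('val').over(w_rows)) \
    .withColumn('sum_range', f.sum('val').over(w_range)) \
    .show()
With rangeBetween, both rows where ord equals 2 end up in the same frame and get the same cumulative sum (3); with rowsBetween, the sum keeps growing row by row (1, 2, 3, 4).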
# Running count of 1's per name, ordered by date
df = cum_sum(df,
             new_col_name='1_group',
             partition_cols='name',
             order_col='str_date',
             value_col='flag')
# Multiplying by flag zeroes the count on rows where flag == 0
df = df.withColumn('streak_1', f.col('flag') * f.col('1_group'))

# Invert the flag and repeat the same steps for the streak of 0's
df = df.withColumn('flag_inverted', 1 - f.col('flag'))
df = cum_sum(df,
             new_col_name='0_group',
             partition_cols='name',
             order_col='str_date',
             value_col='flag_inverted')
df = df.withColumn('streak_0', f.col('flag_inverted') * f.col('0_group'))
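If you want the output in the same shape as the expected table, drop the helper columns afterwards; for the sample data above this should reproduce the streak_0 and streak_1 columns from the question:
(df.drop('1_group', '0_group', 'flag_inverted')
   .orderBy('name', 'str_date')
   .show())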
Upvotes: 2
Reputation: 49260
This requires a difference-of-row-numbers approach: first group consecutive rows that share the same flag value, then rank the rows within each group.
from pyspark.sql import Window
from pyspark.sql import functions as f

# Window definitions
w1 = Window.partitionBy(df.name).orderBy(df.str_date)
w2 = Window.partitionBy(df.name, df.flag).orderBy(df.str_date)

# The difference of the two row numbers is constant across consecutive rows
# with the same flag, so it identifies each run
res = df.withColumn('grp', f.row_number().over(w1) - f.row_number().over(w2))

# Window definition for the streak within each run
w3 = Window.partitionBy(res.name, res.flag, res.grp).orderBy(res.str_date)

streak_res = res.withColumn('streak_0', f.when(res.flag == 1, 0).otherwise(f.row_number().over(w3))) \
                .withColumn('streak_1', f.when(res.flag == 0, 0).otherwise(f.row_number().over(w3)))

streak_res.show()
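To line this up with the expected output, you can drop the helper grp column and sort; a small follow-up using only the columns defined above:
streak_res.drop('grp').orderBy('name', 'str_date').show()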
Upvotes: 7