Barranka

Reputation: 21047

Pyspark: Calculate streak of consecutive observations

I have a Spark (2.4.0) data frame with a column that has just two values (either 0 or 1). I need to calculate the streak of consecutive 0s and 1s in this data, resetting the streak to zero if the value changes.

An example:

from pyspark.sql import (SparkSession, Window)
from pyspark.sql.functions import (to_date, row_number, lead, col)

spark = SparkSession.builder.appName('test').getOrCreate()

# Create dataframe
df = spark.createDataFrame([
    ('2018-01-01', 'John', 0, 0),
    ('2018-01-01', 'Paul', 1, 0),
    ('2018-01-08', 'Paul', 3, 1),
    ('2018-01-08', 'Pete', 4, 0),
    ('2018-01-08', 'John', 3, 0),
    ('2018-01-15', 'Mary', 6, 0),
    ('2018-01-15', 'Pete', 6, 0),
    ('2018-01-15', 'John', 6, 1),
    ('2018-01-15', 'Paul', 6, 1),
], ['str_date', 'name', 'value', 'flag'])

df.orderBy('name', 'str_date').show()
## +----------+----+-----+----+
## |  str_date|name|value|flag|
## +----------+----+-----+----+
## |2018-01-01|John|    0|   0|
## |2018-01-08|John|    3|   0|
## |2018-01-15|John|    6|   1|
## |2018-01-15|Mary|    6|   0|
## |2018-01-01|Paul|    1|   0|
## |2018-01-08|Paul|    3|   1|
## |2018-01-15|Paul|    6|   1|
## |2018-01-08|Pete|    4|   0|
## |2018-01-15|Pete|    6|   0|
## +----------+----+-----+----+

With this data, I'd like to calculate the streak of consecutive zeros and ones, ordered by date and "windowed" by name:

# Expected result:
## +----------+----+-----+----+--------+--------+
## |  str_date|name|value|flag|streak_0|streak_1|
## +----------+----+-----+----+--------+--------+
## |2018-01-01|John|    0|   0|       1|       0|
## |2018-01-08|John|    3|   0|       2|       0|
## |2018-01-15|John|    6|   1|       0|       1|
## |2018-01-15|Mary|    6|   0|       1|       0|
## |2018-01-01|Paul|    1|   0|       1|       0|
## |2018-01-08|Paul|    3|   1|       0|       1|
## |2018-01-15|Paul|    6|   1|       0|       2|
## |2018-01-08|Pete|    4|   0|       1|       0|
## |2018-01-15|Pete|    6|   0|       2|       0|
## +----------+----+-----+----+--------+--------+

Of course, I would need the streak to reset itself to zero if the 'flag' changes.

Is there a way of doing this?

Upvotes: 4

Views: 4441

Answers (2)

Tim

Reputation: 3417

There is a more intuitive solution that avoids row_number(), provided you already have a natural ordering column (str_date in this case).

In short, to find the streak of 1's:

  1. take the cumulative sum of the flag,
  2. then multiply it by the flag.

To find the streak of 0's, invert the flag first and then apply the same two steps.

First, define a helper function to calculate a cumulative sum:

from pyspark.sql import Window 
from pyspark.sql import functions as f

def cum_sum(df, new_col_name, partition_cols, order_col, value_col):
    # Running total per partition, from the first row up to and including the current row
    windowval = (Window.partitionBy(partition_cols).orderBy(order_col)
                 .rowsBetween(Window.unboundedPreceding, 0))
    return df.withColumn(new_col_name, f.sum(value_col).over(windowval))

Note the use of rowsBetween (instead of rangeBetween). This is important to get the correct cumulative sum when there are duplicate values in the order column.
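
As a minimal sketch of that difference (the toy frame and its column names here are hypothetical, not part of the original example): a range frame treats rows with the same order value as peers and sums them all at once, while a row frame advances one row at a time.

# Two rows share ord=2 (hypothetical data, for illustration only)
toy = spark.createDataFrame(
    [('a', 1, 1), ('a', 2, 1), ('a', 2, 1)], ['name', 'ord', 'flag'])

w_rows = (Window.partitionBy('name').orderBy('ord')
          .rowsBetween(Window.unboundedPreceding, 0))
w_range = (Window.partitionBy('name').orderBy('ord')
           .rangeBetween(Window.unboundedPreceding, 0))

toy.select('ord',
           f.sum('flag').over(w_rows).alias('sum_rows'),    # 1, 2, 3: one row at a time
           f.sum('flag').over(w_range).alias('sum_range')   # 1, 3, 3: both ord=2 peers counted together
          ).show()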

Calculate streak of 1's

df = cum_sum(df, 
             new_col_name='1_group', 
             partition_cols='name', 
             order_col='str_date',
             value_col='flag')
df = df.withColumn('streak_1', f.col('flag')*f.col('1_group'))
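
To see the mechanics, it helps to trace one name (the values below follow from the sample data; the filter and select are just for inspection):

# John's flags over time are 0, 0, 1, so the cumulative sum '1_group'
# is 0, 0, 1, and multiplying by the flag gives streak_1 = 0, 0, 1
df.filter(f.col('name') == 'John').orderBy('str_date') \
  .select('str_date', 'flag', '1_group', 'streak_1').show()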

Calculate streak of 0's

df = df.withColumn('flag_inverted', 1-f.col('flag'))

df = cum_sum(df, 
             new_col_name='0_group', 
             partition_cols='name', 
             order_col='str_date',
             value_col='flag_inverted')
df = df.withColumn('streak_0', f.col('flag_inverted')*f.col('0_group'))
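
A quick check against the sample data (dropping the helper columns is just for readability) reproduces the expected output. One caveat worth noting: the cumulative sum never resets, so if a name's flag went 1, then 0, then back to 1, the second run of 1's would keep counting from the first; the sample data has at most one flag change per name, where this is not an issue.

df.drop('1_group', '0_group', 'flag_inverted') \
  .orderBy('name', 'str_date') \
  .show()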

Upvotes: 2

Vamsi Prabhala

Reputation: 49260

This requires a difference-of-row-numbers approach: first group consecutive rows that share the same flag value, then rank the rows within each group.

from pyspark.sql import Window
from pyspark.sql import functions as f

# Window definitions
w1 = Window.partitionBy(df.name).orderBy(df.str_date)
w2 = Window.partitionBy(df.name, df.flag).orderBy(df.str_date)

# Consecutive rows with the same flag get the same group number
res = df.withColumn('grp', f.row_number().over(w1) - f.row_number().over(w2))

# Window definition for the streak within each group
w3 = Window.partitionBy(res.name, res.flag, res.grp).orderBy(res.str_date)
streak_res = res.withColumn('streak_0', f.when(res.flag == 1, 0).otherwise(f.row_number().over(w3))) \
                .withColumn('streak_1', f.when(res.flag == 0, 0).otherwise(f.row_number().over(w3)))
streak_res.show()
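
To see why the difference of the two row numbers is constant within a run of equal flags, trace one name (the values below follow from the sample data):

# Paul's flags over time are 0, 1, 1:
#   row_number over w1 (per name):        1, 2, 3
#   row_number over w2 (per name, flag):  1, 1, 2
#   grp = difference:                     0, 1, 1  -> consecutive equal flags share a grp
res.filter(f.col('name') == 'Paul').orderBy('str_date') \
   .select('str_date', 'flag', 'grp').show()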

Upvotes: 7
