zwang

Reputation: 357

Get the first row that matches some condition over a window in PySpark

To give an example, suppose we have a stream of user actions as follows:

from pyspark.sql import *
spark = SparkSession.builder.appName('test').master('local[8]').getOrCreate()

df = spark.sparkContext.parallelize([
    Row(user=1, action=1, time=1),
    Row(user=1, action=1, time=2),
    Row(user=2, action=1, time=3),
    Row(user=1, action=2, time=4),
    Row(user=2, action=2, time=5),
    Row(user=2, action=2, time=6),
    Row(user=1, action=1, time=7),
    Row(user=2, action=1, time=8),
]).toDF()
df.show()

The dataframe looks like:

+----+------+----+
|user|action|time|
+----+------+----+
|   1|     1|   1|
|   1|     1|   2|
|   2|     1|   3|
|   1|     2|   4|
|   2|     2|   5|
|   2|     2|   6|
|   1|     1|   7|
|   2|     1|   8|
+----+------+----+

Then I want to add a column next_alt_time to each row, giving the time at which the user next switches to a different action type in the following rows. For the input above, the output should be:

+----+------+----+-------------+
|user|action|time|next_alt_time|
+----+------+----+-------------+
|   1|     1|   1|            4|
|   1|     1|   2|            4|
|   2|     1|   3|            5|
|   1|     2|   4|            7|
|   2|     2|   5|            8|
|   2|     2|   6|            8|
|   1|     1|   7|         null|
|   2|     1|   8|         null|
+----+------+----+-------------+

I know I can create a window like this:

wnd = Window().partitionBy('user').orderBy('time').rowsBetween(1, Window.unboundedFollowing)

But I don't know how to impose a condition over the window and select the first row that has a different action than the current row, over the window defined above.

Upvotes: 1

Views: 3321

Answers (1)

mck

Reputation: 42422

Here's how to do it. Spark does not preserve the original dataframe order, but if you check the rows one by one, you can confirm that it gives your expected answer:

from pyspark.sql import Row
from pyspark.sql.window import Window
import pyspark.sql.functions as F

df = spark.sparkContext.parallelize([
    Row(user=1, action=1, time=1),
    Row(user=1, action=1, time=2),
    Row(user=2, action=1, time=3),
    Row(user=1, action=2, time=4),
    Row(user=2, action=2, time=5),
    Row(user=2, action=2, time=6),
    Row(user=1, action=1, time=7),
    Row(user=2, action=1, time=8),
]).toDF()

win = Window().partitionBy('user').orderBy('time')

# flag rows where the action differs from the same user's previous action (null on the first row)
df = df.withColumn('new_action', F.lag('action').over(win) != F.col('action'))
# keep the time only on rows where the action changed; null otherwise
df = df.withColumn('new_action_time', F.when(F.col('new_action'), F.col('time')))
# for each row, take the first non-null change time among the following rows
df = df.withColumn('next_alt_time', F.first('new_action_time', ignorenulls=True).over(win.rowsBetween(1, Window.unboundedFollowing)))

df.show()

+----+------+----+----------+---------------+-------------+
|user|action|time|new_action|new_action_time|next_alt_time|
+----+------+----+----------+---------------+-------------+
|   1|     1|   1|      null|           null|            4|
|   1|     1|   2|     false|           null|            4|
|   1|     2|   4|      true|              4|            7|
|   1|     1|   7|      true|              7|         null|
|   2|     1|   3|      null|           null|            5|
|   2|     2|   5|      true|              5|            8|
|   2|     2|   6|     false|           null|            8|
|   2|     1|   8|      true|              8|         null|
+----+------+----+----------+---------------+-------------+
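
If you only need the original columns plus next_alt_time, you can drop the helper columns and sort by time to restore the original order. A minimal sketch, using the intermediate column names created above:

# discard the intermediate columns and bring back the original time order
result = df.drop('new_action', 'new_action_time').orderBy('time')
result.show()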

Upvotes: 2
