uylmz

Reputation: 1552

How to find change occurrence points in a Spark dataframe

I have a simple dataframe as an example:

val someDF = Seq(
  (1, "A"),
  (2, "A"),
  (3, "A"),
  (4, "B"),
  (5, "B"),
  (6, "A"),
  (7, "A"),
  (8, "A")
).toDF("t", "state")

// this part is half pseudocode
someDF.aggregate((acc, cur) => {
    if (acc.last.state != cur.state) {
        acc.add(cur)
    }
}, List()).show(truncate=false)
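For reference, the accumulator logic the pseudocode above gestures at can be written against a plain Scala collection (a sketch only: it assumes the rows have been pulled into an ordered local Seq, e.g. via someDF.collect(), which does not scale, and Spark's own fold/aggregate give no ordering guarantee across partitions):

```scala
// Driver-side sketch of the fold: keep a row when it is the first one
// or when its state differs from the last kept state.
val rows = Seq((1, "A"), (2, "A"), (3, "A"), (4, "B"),
               (5, "B"), (6, "A"), (7, "A"), (8, "A"))

val changePoints = rows.foldLeft(List.empty[(Int, String)]) {
  case (acc, (t, state)) =>
    // acc is built in reverse, so acc.head is the most recently kept row
    if (acc.isEmpty || acc.head._2 != state) (t, state) :: acc else acc
}.reverse
// changePoints == List((1, "A"), (4, "B"), (6, "A"))
```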

The "t" column represents points in time and the "state" column represents the state at that point in time.

What I wish to find is the first row where each state change occurs, plus the first row overall, as in:

(1, "A")
(4, "B")
(6, "A")

I looked at SQL solutions too, but they involve complex self-joins and window functions that I don't completely understand; still, an SQL solution would be fine.

Spark has numerous functions (fold, aggregate, reduce, ...) that feel like they could do this, but I couldn't grasp the differences between them. I'm new to Spark concepts like partitioning, and it's unclear to me whether partitioning could affect the results.

Upvotes: 0

Views: 28

Answers (1)

mck

Reputation: 42342

You can use the window function lag to compare each row with the previous one, and row_number to check whether it's the first row:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val result = someDF.withColumn(
    "change",
    // true when the state differs from the previous row's state,
    // or for the very first row (where lag returns null)
    lag("state", 1).over(Window.orderBy("t")) =!= col("state") ||
    row_number().over(Window.orderBy("t")) === 1
).filter("change").drop("change")

result.show
+---+-----+
|  t|state|
+---+-----+
|  1|    A|
|  4|    B|
|  6|    A|
+---+-----+

For an SQL solution:

someDF.createOrReplaceTempView("mytable")
val result = spark.sql("""
    select t, state 
    from (
        select 
            t, state, 
            lag(state) over (order by t) != state or row_number() over (order by t) = 1 as change 
        from mytable
    ) 
    where change
""")
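To build intuition for what lag is doing in both solutions, the same comparison can be mimicked on a plain Scala collection (a local sketch, not Spark code): pair each row with the previous row's state, using None for the first row so it is always kept.

```scala
// Mimic lag(state) over (order by t): shift the state column down by one.
val rows = Seq((1, "A"), (2, "A"), (3, "A"), (4, "B"),
               (5, "B"), (6, "A"), (7, "A"), (8, "A"))

// Prepend None so the first row pairs with "no previous state",
// just as lag yields null for the first row in a window.
val lagged = (None +: rows.map { case (_, s) => Some(s) }).zip(rows)

// Keep rows whose state differs from the previous one; the first row's
// previous state is None, so !prev.contains(s) is true and it is kept.
val changes = lagged.collect { case (prev, (t, s)) if !prev.contains(s) => (t, s) }
// changes == Seq((1, "A"), (4, "B"), (6, "A"))
```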

Upvotes: 1
