Reputation: 1552
I have a simple DataFrame as an example:
val someDF = Seq(
  (1, "A"),
  (2, "A"),
  (3, "A"),
  (4, "B"),
  (5, "B"),
  (6, "A"),
  (7, "A"),
  (8, "A")
).toDF("t", "state")
// this part is half pseudocode: keep a row only when its state
// differs from the state of the last row kept
someDF.aggregate((acc, cur) => {
  if (acc.isEmpty || acc.last.state != cur.state) acc :+ cur
  else acc
}, List()).show(truncate = false)
"t" column represents points in time and "state" column represents the state at that point in time.
What I wish to find is the first time where each change occurs plus the first row, as in:
(1, "A")
(4, "B")
(6, "A")
I looked at SQL solutions too, but they involve complex self-joins and window functions which I don't completely understand; still, an SQL solution would be OK as well.
There are numerous functions in Spark (fold, aggregate, reduce, ...) that I feel could do this, but I couldn't grasp the differences between them. I'm also new to Spark concepts like partitioning, and I'm unsure whether partitioning could affect the results.
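For intuition, the run-compression the pseudocode above describes can be sketched on a plain Scala `List` with `foldLeft`. This is not Spark code, just a local sketch, and the second half illustrates the partitioning concern: a fold applied independently per partition can double-count a run that crosses a partition boundary.

```scala
// Non-Spark sketch: keep the first row of each run of equal states in an
// ordered local List. A local List has a guaranteed order, which the rows
// inside a distributed DataFrame's partitions do not.
def compress(rows: List[(Int, String)]): List[(Int, String)] =
  rows.foldLeft(List.empty[(Int, String)]) {
    case (acc @ ((_, lastState) :: _), cur) if lastState == cur._2 =>
      acc        // same state as the last kept row: skip
    case (acc, cur) =>
      cur :: acc // first row, or the state changed: keep
  }.reverse

val data = List((1, "A"), (2, "A"), (3, "A"), (4, "B"),
                (5, "B"), (6, "A"), (7, "A"), (8, "A"))

println(compress(data)) // List((1,A), (4,B), (6,A))

// Why a naive per-partition fold is unsafe: split the list in the middle
// of the trailing "A" run and compress each part on its own — the
// boundary row (8,"A") is wrongly kept.
val parts = List(data.take(7), data.drop(7))
println(parts.flatMap(compress)) // List((1,A), (4,B), (6,A), (8,A))
```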
Upvotes: 0
Views: 28
Reputation: 42342
You can use the window function lag
to compare each row with the previous one, and row_number
to check whether it's the first row. The row_number check is needed because
lag returns null for the first row, so the inequality against it evaluates
to null rather than true:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val result = someDF.withColumn(
  "change",
  lag("state", 1).over(Window.orderBy("t")) =!= col("state") ||
    row_number().over(Window.orderBy("t")) === 1
).filter("change").drop("change")
result.show
+---+-----+
| t|state|
+---+-----+
| 1| A|
| 4| B|
| 6| A|
+---+-----+
For an SQL solution:
someDF.createOrReplaceTempView("mytable")
val result = spark.sql("""
  select t, state
  from (
    select
      t, state,
      lag(state) over (order by t) != state
        or row_number() over (order by t) = 1 as change
    from mytable
  )
  where change
""")
Upvotes: 1