Reputation: 1611
I have a requirement where a DataFrame is sorted by col1 (a timestamp) and I need to filter on col2. Any row whose col2 value is less than the col2 value of an earlier row should be filtered out, so that the result has monotonically increasing col2 values.
Note that this is not just a comparison of two adjacent rows.
For example, if the col2 values of 4 rows are 4, 2, 3, 5, the result should be 4, 5: both the 2nd and 3rd rows are dropped because they are less than 4 (the first row's value).
val input = Seq(
  (1,4), (2,2), (3,3), (4,5), (5, 1), (6, 9), (7, 6)
).toDF("timestamp", "value")
scala> input.show
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 2| 2|
| 3| 3|
| 4| 5|
| 5| 1|
| 6| 9|
| 7| 6|
+---------+-----+
val expected = Seq((1,4), (4,5), (6, 9)).toDF("timestamp", "value")
scala> expected.show
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 4| 5|
| 6| 9|
+---------+-----+
More generically: is there a way to filter rows based on comparing one row's value with values in the previous rows?
Upvotes: 1
Views: 928
Reputation: 27373
Checking equality with the running maximum should do the trick:
val input = Seq((1,4), (2,2), (3,3), (4,5), (5, 1), (6, 9), (7, 6)).toDF("timestamp", "value")
input.show()
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 2| 2|
| 3| 3|
| 4| 5|
| 5| 1|
| 6| 9|
| 7| 6|
+---------+-----+
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

input
  .withColumn("max", max($"value").over(Window.orderBy($"timestamp")))
  .where($"value" === $"max")
  .drop($"max")
  .show()
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 4| 5|
| 6| 9|
+---------+-----+
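To the generic part of the question (comparing a row against values in previous rows), the same windowing idea applies: define a frame that ends just before the current row and put any aggregate over it. A minimal sketch along those lines (untested here; prev_max is an illustrative name):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

// frame over all rows strictly before the current one
val prevRows = Window.orderBy($"timestamp").rowsBetween(Window.unboundedPreceding, -1)

input
  .withColumn("prev_max", max($"value").over(prevRows)) // any aggregate over earlier rows works
  .where($"prev_max".isNull || $"value" >= $"prev_max") // the first row has no predecessors
  .drop("prev_max")
  .show()
This produces the same result as the running-max check above, but the frame makes the "previous rows only" comparison explicit.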
Upvotes: 1
Reputation: 74669
I think what you're after is called a running maximum (by analogy with a running total). That always leads me to windowed aggregation.
// I made the input a bit more tricky
val input = Seq(
  (1,4), (2,2), (3,3), (4,5), (5, 1), (6, 9), (7, 6)
).toDF("timestamp", "value")
scala> input.show
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 2| 2|
| 3| 3|
| 4| 5|
| 5| 1|
| 6| 9|
| 7| 6|
+---------+-----+
I'm aiming at the following expected result. Correct me if I'm wrong.
val expected = Seq((1,4), (4,5), (6, 9)).toDF("timestamp", "value")
scala> expected.show
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 4| 5|
| 6| 9|
+---------+-----+
The trick for "running" problems is to use rangeBetween when defining the window specification.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.max

val ts = Window
  .orderBy("timestamp")
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)
With the window spec in place, you filter out the rows you want to get rid of and you're done.
val result = input
  .withColumn("running_max", max("value") over ts)
  .where($"running_max" === $"value")
  .select("timestamp", "value")
scala> result.show
18/05/29 22:09:18 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---------+-----+
|timestamp|value|
+---------+-----+
| 1| 4|
| 4| 5|
| 6| 9|
+---------+-----+
As you can see, it's not very efficient, since the window moves all the data to a single partition (which leads to single-threaded execution, so not much different from running the computation on a single machine).
I think we could partition the input, compute the running maximum within each partition, then union the partial results and re-run the running maximum calculation on top of them. Just a thought I have not tried out myself; a sketch of the idea follows.
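A minimal, untested sketch of that two-pass idea, assuming timestamps can be split into fixed-width range buckets (bucketWidth, bucket, local_max and prior_max are all illustrative names, not part of the question):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// assumption: timestamps are spread evenly enough that fixed-width range
// buckets give balanced partitions; bucketWidth is illustrative
val bucketWidth = 2
val bucketed = input.withColumn("bucket", floor($"timestamp" / bucketWidth))

// pass 1: running max within each bucket (buckets are processed in parallel)
val perBucket = Window.partitionBy("bucket").orderBy("timestamp")
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)
val withLocalMax = bucketed.withColumn("local_max", max("value") over perBucket)

// tiny one-row-per-bucket summary: each bucket's max plus the running max
// of all *earlier* buckets
val priorBuckets = Window.orderBy("bucket").rowsBetween(Window.unboundedPreceding, -1)
val bucketPrefix = bucketed
  .groupBy("bucket").agg(max("value") as "bucket_max")
  .withColumn("prior_max", max("bucket_max") over priorBuckets)

// pass 2: a row survives iff its value equals the global running max, i.e.
// the greater of its in-bucket running max and the max of all earlier
// buckets; greatest skips the null prior_max of the first bucket
val result = withLocalMax
  .join(broadcast(bucketPrefix.select("bucket", "prior_max")), "bucket")
  .where($"value" === greatest($"local_max", $"prior_max"))
  .select("timestamp", "value")
The per-bucket windows run in parallel; only the one-row-per-bucket summary goes through a single-partition window, which is cheap.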
Upvotes: 1