Reputation: 682
I'm trying to calculate an aggregate function like sum() over a window with a range, BUT I only want to include every Nth row, AND the skipping should be relative to the front of the window (always include the first row in the window).
//val df = some Dataframe {symbol,datetime,metric}
val baseWin = Window.partitionBy("symbol").orderBy("datetime")
//This is a plain sum over the window
val plain = sum(col("metric")).over(baseWin.rowsBetween(-12,0))
//This is ALMOST what I want (every 3rd) BUT isn't relative to the window
val almost = sum(when(col("datetime") / lit(DAY) % 3 === 0, col("metric"))).over(baseWin.rowsBetween(-12, 0))
Upvotes: 4
Views: 669
Reputation: 35219
You can use `lag`. With the range defined as:
scala> (0 to 12 by 3)
res1: scala.collection.immutable.Range = Range(0, 3, 6, 9, 12)
you can sum all the lags (defaulting to 0 for rows before the start of the partition):
val almost = (0 to 12 by 3).map(lag($"metric", _, 0).over(baseWindow)).reduce(_ + _)
This expands to lag($"metric", 0, 0) + lag($"metric", 3, 0) + ... + lag($"metric", 12, 0). The offsets count back from the current row, but since the 13-row window ends at the current row, the lag-12 term is exactly the first row of the window.
Example:
val df = spark.range(24).toDF("metric").withColumn("group", $"metric" > 12)
val baseWindow = Window.partitionBy("group").orderBy("metric")
df.withColumn("almost", almost).show
// +------+-----+------+
// |metric|group|almost|
// +------+-----+------+
// | 13| true| 13| 13
// | 14| true| 14| 14
// | 15| true| 15| 15
// | 16| true| 29| 16 + 13
// | 17| true| 31| 17 + 14
// | 18| true| 33| 18 + 15
// | 19| true| 48| 19 + 16 + 13
// | 20| true| 51| 20 + 17 + 14
// | 21| true| 54| 21 + 18 + 15
// | 22| true| 70| 22 + 19 + 16 + 13
// | 23| true| 74| 23 + 20 + 17 + 14
// | 0|false| 0| ...
// | 1|false| 1| 1
// | 2|false| 2| 2
// | 3|false| 3| 3
// | 4|false| 5| 4 + 1
// | 5|false| 7| 5 + 2
// | 6|false| 9| 6 + 3
// | 7|false| 12| 7 + 4 + 1
// | 8|false| 15| 8 + 5 + 2
// +------+-----+------+
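If you want to sanity-check the expected sums without spinning up Spark, the lag arithmetic is easy to simulate in plain Python (this is just a check of the logic, not the Spark API; `every_third_sum` is a hypothetical helper mirroring the Scala expression above):

```python
def every_third_sum(values, max_lag=12, step=3):
    """For each index i, sum values[i - k] for k in 0, 3, ..., max_lag,
    treating lags that fall before the start of the list as 0 (lag's default).
    Mirrors: (0 to 12 by 3).map(lag($"metric", _, 0).over(baseWindow)).reduce(_ + _)
    """
    out = []
    for i in range(len(values)):
        out.append(sum(values[i - k]
                       for k in range(0, max_lag + 1, step)
                       if i - k >= 0))
    return out

# The example data above: metrics 0..23 split into two groups by metric > 12.
true_group = list(range(13, 24))   # metrics 13..23
false_group = list(range(0, 13))   # metrics 0..12

print(every_third_sum(true_group))        # [13, 14, 15, 29, 31, 33, 48, 51, 54, 70, 74]
print(every_third_sum(false_group)[:9])   # [0, 1, 2, 3, 5, 7, 9, 12, 15]
```

The printed lists match the `almost` column in the `show` output above (only the first 20 rows of the DataFrame are displayed).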
Upvotes: 5