Reputation: 1180
I want to sum the values while one column is 'relative'
and restart the sum whenever it is 'absolute'.
Here is how I defined my DataFrame:
import spark.implicits._

val df = sc.parallelize(Seq(
  (1, "2018-02-21", "relative", 3.00),
  (1, "2018-02-22", "relative", 4.00),
  (1, "2018-02-23", "absolute", 5.00),
  (1, "2018-02-24", "relative", 6.00),
  (1, "2018-02-26", "relative", 8.00)
)).toDF("id", "date", "updateType", "value")
I defined a UDF to decide when to sum and when not to: I want to order by date and then either add the value to the previous running total or take the absolute value as-is.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, udf}

val computeValue = udf((previous: java.math.BigDecimal, value: java.math.BigDecimal, updateType: String) => {
  updateType match {
    case "absolute" => value
    case "relative" => previous.add(value)
    case _          => previous
  }
})
val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    lag($"value", 1, 0).over(w),
    $"value",
    $"updateType"
  ).alias("sumValue")
)
This actually returns the following (lag only gives me the previous raw value, not the running total):
+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|     3.00|
|  1|2018-02-22|     7.00|
|  1|2018-02-23|     5.00|
|  1|2018-02-24|    11.00|
|  1|2018-02-26|    14.00|
+---+----------+---------+
And what I'm looking for is:
+---+----------+---------+
| id|      date| sumValue|
+---+----------+---------+
|  1|2018-02-21|     3.00|
|  1|2018-02-22|     7.00|
|  1|2018-02-23|     5.00|
|  1|2018-02-24|    11.00|
|  1|2018-02-26|    19.00|
+---+----------+---------+
Upvotes: 1
Views: 1384
Reputation: 1180
The answer is to use a UDAF (User Defined Aggregate Function) for this kind of operation.
// Init the aggregation function that computes the running value
val computeValue = new ComputeValue

val w = Window
  .partitionBy($"id")
  .orderBy($"date")

val result = df.select(
  $"id",
  $"date",
  computeValue(
    $"value",
    $"updateType"
  ).over(w).alias("sumValue")
)
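Applied over the ordered window, the aggregate walks each partition row by row, so the running value restarts at every 'absolute' row. A quick check against the desired output (the exact decimal formatting of show() may differ slightly):
result.orderBy($"id", $"date").show()
// +---+----------+--------+
// | id|      date|sumValue|
// +---+----------+--------+
// |  1|2018-02-21|     3.0|
// |  1|2018-02-22|     7.0|
// |  1|2018-02-23|     5.0|
// |  1|2018-02-24|    11.0|
// |  1|2018-02-26|    19.0|
// +---+----------+--------+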
Where the ComputeValue UDAF is:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ComputeValue extends UserDefinedAggregateFunction {
  // Each input row carries value: Double and update_type: String
  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) ::
      StructField("update_type", StringType) :: Nil)

  // A single buffer column where I keep the internal running value
  override def bufferSchema: StructType = StructType(
    StructField("value", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0.0

  // This is how to update your buffer given an input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = computeValue(buffer, input)
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = computeValue(buffer1, buffer2)
  }

  // Get the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }

  private def computeValue(buffer: MutableAggregationBuffer, row: Row): Double = {
    val updateType: String = row.getAs[String](1)
    val prev: Double = buffer.getDouble(0)
    val current: Double = row.getAs[Double](0)
    updateType match {
      case "relative" => prev + current
      case "absolute" => current
      case _          => current
    }
  }
}
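For reference, the same reset-on-'absolute' running sum can also be written with built-in window functions only. This is a minimal sketch of my own (not part of the UDAF approach above), assuming the sample DataFrame from the question: every 'absolute' row opens a new group, and an ordinary cumulative sum is computed inside each group.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{sum, when}

val byDate = Window.partitionBy($"id").orderBy($"date")

val resultNoUdaf = df
  // running count of 'absolute' rows seen so far = group id
  .withColumn("grp",
    sum(when($"updateType" === "absolute", 1).otherwise(0)).over(byDate))
  // plain cumulative sum, restarted inside every group
  .withColumn("sumValue",
    sum($"value").over(Window.partitionBy($"id", $"grp").orderBy($"date")))
  .select($"id", $"date", $"sumValue")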
Upvotes: 2