I have seen this being done for DataFrames in other posts: https://stackoverflow.com/a/52992212/4080521
But I am trying to figure out how I can write a UDF for a cumulative product.
Assuming I have a very basic table
Input data:
+----+
| val|
+----+
|   1|
|   2|
|   3|
+----+
If I want to take the sum of this, I can simply do something like
df.createOrReplaceTempView("table")
spark.sql("""Select SUM(table.val) from table""").show(100, false)
and this simply works because SUM is a predefined function.
How would I define something similar for multiplication (or even how could I implement SUM in a UDF myself)?
Trying the following:
df.createOrReplaceTempView("table")
val prod = udf((vals: Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod", prod)
spark.sql("""Select prod(table.vals) from table""").show(100, false)
I get the following error:
Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)
Obviously each individual cell is not an array, but it seems the UDF needs to take an array to perform the aggregation. Is this even possible with Spark SQL?
You can implement this through a UserDefinedAggregateFunction. A plain UDF operates on a single row at a time, which is why your attempt fails; aggregating across rows needs an aggregate function that maintains a buffer, and you have to define several methods that work with the input and the buffer values.
A quick example for the product function, using Double as the type:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class myUDAF extends UserDefinedAggregateFunction {
  // Input schema for the function
  override def inputSchema: StructType = {
    new StructType().add("val", DoubleType, nullable = true)
  }

  // Schema for the internal UDAF buffer; for a product you just need an accumulator
  override def bufferSchema: StructType = StructType(StructField("accumulated", DoubleType) :: Nil)

  // Output data type
  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // Initial buffer value: 1 is the neutral element for a product
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 1.0

  // How to update the buffer: for a product, multiply the buffer by the input
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Double](0) * input.getAs[Double](0)
  }

  // Merge partial results from two buffers (a product as well here)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
  }

  // How to return the final value
  override def evaluate(buffer: Row): Any = buffer.getAs[Double](0)
}
Then you can register the function as you would any other UDF:
spark.udf.register("prod", new myUDAF)
RESULT
scala> spark.sql("Select prod(val) from table").show
+-----------+
|myudaf(val)|
+-----------+
| 6.0|
+-----------+
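A UserDefinedAggregateFunction also exposes an apply method, so you can use the same class directly with the DataFrame API instead of going through SQL. A minimal sketch, where df is a hypothetical name for the DataFrame backing the "table" view:

import org.apache.spark.sql.functions.col

// Instantiate the UDAF and apply it like any built-in aggregate
val prodUDAF = new myUDAF
df.agg(prodUDAF(col("val"))).show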
You can find further documentation in the Spark API docs for UserDefinedAggregateFunction.
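As a side note, UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of the typed Aggregator API. A minimal sketch of the same product aggregation with that API, assuming Spark 3.0+:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

// Typed product aggregator: input, buffer and output are all Double
object ProdAgg extends Aggregator[Double, Double, Double] {
  def zero: Double = 1.0                               // neutral element for a product
  def reduce(acc: Double, x: Double): Double = acc * x // fold one input value into the buffer
  def merge(a: Double, b: Double): Double = a * b      // combine partial products from different partitions
  def finish(acc: Double): Double = acc                // the buffer already holds the result
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Wrap the Aggregator as an untyped UDAF and register it for SQL
spark.udf.register("prod", udaf(ProdAgg))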