Steven Hsu

Reputation: 183

Cumulative product UDF for Spark SQL

I have seen this done for DataFrames in other posts: https://stackoverflow.com/a/52992212/4080521

But I am trying to figure out how I can write a UDF for a cumulative product.

Assuming I have a very basic table

Input data:
+----+
| val|
+----+
| 1  |
| 2  |
| 3  |
+----+

If I want to take the sum of this, I can simply do something like

df.createOrReplaceTempView("table")  // df is the DataFrame holding the input data
spark.sql("""SELECT SUM(table.val) FROM table""").show(100, false)

and this simply works because SUM is a predefined aggregate function.

How would I define something similar for multiplication (or how could I even implement SUM in a UDF myself)?

Trying the following

df.createOrReplaceTempView("table")

val prod = udf((vals: Seq[Decimal]) => vals.reduce(_ * _))
spark.udf.register("prod", prod)

spark.sql("""SELECT prod(table.vals) FROM table""").show(100, false)

I get the following error:

Message: cannot resolve 'UDF(vals)' due to data type mismatch: argument 1 requires array<decimal(38,18)> type, however, 'table.vals' is of decimal(28,14)

Obviously each specific cell is not an array, but it seems the UDF needs to take in an array to perform the aggregation. Is this even possible with Spark SQL?

Upvotes: 1

Views: 319

Answers (1)

SCouto

Reputation: 7928

You can implement it through a UserDefinedAggregateFunction. You need to define several functions to work with the input and the buffer values.

Quick example for the product function, using just Double as the type:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class myUDAF extends UserDefinedAggregateFunction {

  // Input schema for the function
  override def inputSchema: StructType =
    new StructType().add("val", DoubleType, nullable = true)

  // Schema for the internal UDAF buffer; for a product you just need an accumulator
  override def bufferSchema: StructType =
    StructType(StructField("accumulated", DoubleType) :: Nil)

  // Output data type
  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // Initial buffer value: 1 for a product
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 1.0

  // How to update the buffer: for a product, just multiply the two elements (buffer & input)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Double](0) * input.getAs[Double](0)
  }

  // Merge with a previously buffered partial result (a product here as well)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
  }

  // How to produce the final value from the buffer
  override def evaluate(buffer: Row): Double = buffer.getAs[Double](0)

}

Then you can register the function as you would do with any other UDF:

spark.udf.register("prod", new myUDAF)
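
For reference, the result below assumes the question's sample data has been registered as a temp view. A minimal sketch of that setup (the DataFrame name df is my assumption):

import spark.implicits._

// Hypothetical setup: build the question's sample data and expose it as the "table" view
val df = Seq(1.0, 2.0, 3.0).toDF("val")
df.createOrReplaceTempView("table")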

RESULT

scala> spark.sql("SELECT prod(val) FROM table").show
+-----------+
|myudaf(val)|
+-----------+
|        6.0|
+-----------+
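
Since the question title asks for a cumulative product, note that the same UDAF can also be applied over a window to get a running product. A sketch, assuming your Spark version supports UDAFs in window expressions; ordering by "val" is my assumption, since the sample table has no explicit ordering column:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.col

val prodUDAF = new myUDAF

// Running product from the first row up to the current one
val w = Window.orderBy("val").rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("cum_prod", prodUDAF(col("val")).over(w)).show()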

You can find further documentation in the Spark API docs for UserDefinedAggregateFunction.
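
Side note: on Spark 3.0+, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API, registered through functions.udaf. A minimal sketch of the same product aggregation with that API (the object name ProductAgg is mine):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{Encoder, Encoders}

// Typed product aggregator: input, buffer and output are all Double
object ProductAgg extends Aggregator[Double, Double, Double] {
  def zero: Double = 1.0                               // neutral element for a product
  def reduce(acc: Double, x: Double): Double = acc * x // fold one input row into the buffer
  def merge(a: Double, b: Double): Double = a * b      // combine partial buffers
  def finish(acc: Double): Double = acc                // final result is the buffer itself
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

spark.udf.register("prod", udaf(ProductAgg))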

Upvotes: 1
