TomTom101

Reputation: 6892

UDAF in Spark with multiple input columns

I am trying to develop a user defined aggregate function that computes a linear regression on a series of numbers. I have successfully written a UDAF that calculates confidence intervals of means (with a lot of trial and error and SO!).

Here's what actually runs for me already:

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}

case class RegressionData(intercept: Double, slope: Double)

class Regression {

  import org.apache.commons.math3.stat.regression.SimpleRegression

  def roundAt(p: Int)(n: Double): Double = {
    val s = math.pow(10, p)
    math.round(n * s) / s
  }

  def getRegression(data: List[Long]): RegressionData = {
    val regression = new SimpleRegression()
    data.view.zipWithIndex.foreach { case (value, index) =>
      regression.addData(index.toDouble, value.toDouble)
    }

    RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
  }
}


class UDAFRegression extends UserDefinedAggregateFunction {

  import java.util.ArrayList

  def deterministic = true

  def inputSchema: StructType =
    new StructType().add("units", LongType)

  def bufferSchema: StructType =
    new StructType().add("buff", ArrayType(LongType))


  def dataType: DataType =
    new StructType()
      .add("intercept", DoubleType)
      .add("slope", DoubleType)

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, new ArrayList[Long]())
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer.getList(0))
    longList.add(input.getLong(0))
    buffer.update(0, longList)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
    longList.addAll(buffer2.getList(0))

    buffer1.update(0, longList)
  }


  def evaluate(buffer: Row) = {
    import scala.collection.JavaConverters._
    val list = buffer.getList(0).asScala.toList
    val regression = new Regression
    regression.getRegression(list)
  }
}

However, the datasets do not come in order, which is obviously very important here. Hence instead of regression($"longValue") I need a second parameter: regression($"longValue", $"created_day"). created_day is a sql.types.DateType.

I am pretty confused by DataTypes, StructTypes and what-not, and due to the lack of examples on the web, I got stuck with my trial and error attempts here.

What would my bufferSchema look like?

Are those StructTypes overhead in my case? Wouldn't a (mutable) Map just do? Is MapType actually immutable, and isn't that rather pointless for a buffer type?

What would my inputSchema look like?

Does this have to match the type I retrieve in update(), in my case via input.getLong(0)?

Is there a standard way to reset the buffer in initialize()?

I have seen buffer.update(0, 0.0) (when it contains Doubles, obviously), buffer(0) = new WhatEver() and I think even buffer = Nil. Do any of these make a difference?

How to update data?

The example above seems over-complicated. I was expecting to be able to do something like buffer += input.getLong(0) -> input.getDate(1). Can I expect to access the input this way?

How to merge data?

Can I just leave the function block empty like def merge(…) = {}?

The challenge of sorting that buffer in evaluate() is something I should be able to figure out, although I am still interested in the most elegant ways you would do this (in a fraction of the time).
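For the sorting itself, one possible sketch in plain Scala (no Spark types; the day-keyed map below is a hypothetical stand-in for the buffer content): sort the map entries by their date key, then keep only the values in that order before feeding them to the regression.

```scala
object SortBufferSketch {
  // Hypothetical buffer content: epoch day -> measured value.
  // Sorting by the key restores chronological order; the regression
  // then only sees the values, already in date order.
  def orderedValues(byDay: Map[Long, Long]): List[Long] =
    byDay.toList.sortBy(_._1).map(_._2)
}
```

With a map like `Map(3L -> 30L, 1L -> 10L, 2L -> 20L)` this yields `List(10L, 20L, 30L)`, which can be passed straight into `getRegression`.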

Bonus question: What role does dataType play?

I return a case class, not the StructType as defined in dataType which does not seem to be an issue. Or is it working since it happens to match my case class?

Upvotes: 0

Views: 2330

Answers (1)

David Griffin

Reputation: 13927

Maybe this will clear things up.

The UDAF APIs work on DataFrame Columns. Everything you are doing has to get serialized just like all the other Columns in the DataFrame. As you note, the only supported MapType is immutable, because this is the only thing you can put in a Column. With immutable collections, you just create a new collection that contains the old collection plus a value:

var map = Map[Long,Long]()
map = map + (0L -> 1234L)
map = map + (1L -> 4567L)
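The same immutable pattern carries over to merge: combining two partial buffers is just `++` on the two maps, which again returns a new map (a plain-Scala sketch, no Spark types involved):

```scala
object MergeSketch {
  // Combining two immutable maps: ++ returns a fresh map containing both.
  // On duplicate keys the right-hand side wins, which is harmless here
  // if each day key appears in only one partial buffer.
  def mergeBuffers(a: Map[Long, Long], b: Map[Long, Long]): Map[Long, Long] =
    a ++ b
}
```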

Yes, just like working with any DataFrame, your types have to match. Doing buffer.getInt(0) when there's really a LongType there is going to be a problem.

There's no standard way to reset the buffer other than whatever makes sense for your data type / use case. Maybe zero is actually last month's balance; maybe zero is a running average from another dataset; maybe zero is null or an empty string; or maybe zero is really zero.

merge is an optimization that only happens in certain circumstances, if I remember correctly -- a way to sub-total that the SQL optimizer may use if the circumstances warrant it. I just use the same logic for merge that I use for update.

A case class will automatically get converted to the appropriate schema, so for the bonus question the answer is yes, it's because the schemas match. Change the dataType so it does not match, and you will get an error.
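Putting these pieces together, a sketch of the two-column UDAF might look like the following. This is an untested outline against the UserDefinedAggregateFunction API, and the epoch-day conversion via getDate(1).toLocalDate.toEpochDay is one assumption among several (any monotonic Long key derived from the date would do); it reuses the Regression class from the question.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Sketch: two input columns (value, created_day), an immutable MapType
// buffer keyed by epoch day, and a sort-by-key in evaluate().
class UDAFRegression2 extends UserDefinedAggregateFunction {

  def deterministic: Boolean = true

  // Two inputs: the long value and the date it belongs to.
  def inputSchema: StructType =
    new StructType()
      .add("units", LongType)
      .add("created_day", DateType)

  // Buffer holds an immutable map: epoch day -> value.
  def bufferSchema: StructType =
    new StructType().add("buff", MapType(LongType, LongType))

  def dataType: DataType =
    new StructType()
      .add("intercept", DoubleType)
      .add("slope", DoubleType)

  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[Long, Long])

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The getters must match inputSchema: getLong(0) and getDate(1).
    val day = input.getDate(1).toLocalDate.toEpochDay
    buffer.update(0, buffer.getMap[Long, Long](0).toMap + (day -> input.getLong(0)))
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getMap[Long, Long](0).toMap ++ buffer2.getMap[Long, Long](0))

  def evaluate(buffer: Row): Any = {
    // Restore chronological order before running the regression.
    val ordered = buffer.getMap[Long, Long](0).toList.sortBy(_._1).map(_._2)
    new Regression().getRegression(ordered)
  }
}
```

Note that evaluate still returns the RegressionData case class; per the bonus answer above, that works because its fields line up with the StructType in dataType.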

Upvotes: 2
