I am trying to develop a user-defined aggregate function (UDAF) that computes a linear regression over a sequence of numbers. I have already written a working UDAF that calculates confidence intervals of means (with a lot of trial and error and SO!).
Here's what actually runs for me already:
```scala
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}

case class RegressionData(intercept: Double, slope: Double)

class Regression {
  import org.apache.commons.math3.stat.regression.SimpleRegression

  // Round n to p decimal places
  def roundAt(p: Int)(n: Double): Double = { val s = math.pow(10, p); math.round(n * s) / s }

  // Fit y = intercept + slope * x, using each value's position in the list as x
  def getRegression(data: List[Long]): RegressionData = {
    val regression: SimpleRegression = new SimpleRegression()
    data.view.zipWithIndex.foreach { d =>
      regression.addData(d._2.toDouble, d._1.toDouble)
    }
    RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
  }
}

class UDAFRegression extends UserDefinedAggregateFunction {
  import java.util.ArrayList

  def deterministic: Boolean = true

  // One input column: the value to regress
  def inputSchema: StructType =
    new StructType().add("units", LongType)

  // The buffer collects every value seen so far
  def bufferSchema: StructType =
    new StructType().add("buff", ArrayType(LongType))

  // Result schema; its fields line up with RegressionData
  def dataType: DataType =
    new StructType()
      .add("intercept", DoubleType)
      .add("slope", DoubleType)

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new ArrayList[Long]())
  }

  // Copy the buffer's list, append the new value, write the list back
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer.getList(0))
    longList.add(input.getLong(0))
    buffer.update(0, longList)
  }

  // Concatenate two partial buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
    longList.addAll(buffer2.getList(0))
    buffer1.update(0, longList)
  }

  def evaluate(buffer: Row): Any = {
    import scala.collection.JavaConverters._
    val list = buffer.getList(0).asScala.toList
    val regression = new Regression
    regression.getRegression(list)
  }
}
```
However, the rows do not arrive in order, and order obviously matters here. So instead of regression($"longValue") I need a second parameter, regression($"longValue", $"created_day"), where created_day is a sql.types.DateType.
I am pretty confused by DataTypes, StructTypes and whatnot. Because of the lack of examples on the web, I got stuck with my trial-and-error attempts here.
What should the bufferSchema look like? Are those StructTypes overhead in my case? Wouldn't a (mutable) Map just do? Is MapType actually immutable, and doesn't that make it rather pointless as a buffer type?
What should the inputSchema look like? Does it have to match the type I retrieve in update(), in my case via input.getLong(0)?
Is there a standard way to reset the buffer in initialize()? I have seen buffer.update(0, 0.0) (when it contains Doubles, obviously), buffer(0) = new WhatEver() and I think even buffer = Nil. Do any of these make a difference?
For update(), the example above seems overcomplicated. I was expecting to be able to do something like buffer += input.getLong(0) -> input.getDate(1). Can I expect to access the input this way?
For merge(), can I just leave the function body empty, like def merge(…) = {}?
Sorting that buffer in evaluate() is something I should be able to figure out, although I am still interested in the most elegant way you would do it (in a fraction of the time).
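One fairly compact way to do the sort is to carry (day, value) pairs, order them by day, and then fit against position exactly as getRegression does. The plain-Scala sketch below runs without Spark and makes some assumptions. java.time.LocalDate stands in for the DateType values, and the SortedRegression object is a hypothetical name. The least-squares fit is computed by hand instead of with commons-math SimpleRegression, only to keep the example dependency-free.

```scala
import java.time.LocalDate

// Hypothetical helper: order (day, value) pairs by day, then fit
// value = intercept + slope * position by ordinary least squares.
object SortedRegression {
  def fit(pairs: Seq[(LocalDate, Long)]): (Double, Double) = {
    // sortBy is stable, so rows sharing a day keep their incoming order
    val ys = pairs.sortBy(_._1.toEpochDay).map(_._2.toDouble)
    val xs = ys.indices.map(_.toDouble)
    val meanX = xs.sum / ys.size
    val meanY = ys.sum / ys.size
    val sxy = xs.zip(ys).map { case (x, y) => (x - meanX) * (y - meanY) }.sum
    val sxx = xs.map(x => (x - meanX) * (x - meanX)).sum
    val slope = sxy / sxx
    (meanY - slope * meanX, slope) // (intercept, slope)
  }

  def main(args: Array[String]): Unit = {
    val shuffled = Seq(
      LocalDate.of(2016, 1, 3) -> 3L,
      LocalDate.of(2016, 1, 1) -> 1L,
      LocalDate.of(2016, 1, 2) -> 2L)
    println(fit(shuffled)) // (1.0,1.0)
  }
}
```

Like the original code, this uses the position after sorting as x. If gaps between days should count, toEpochDay could serve as x instead, but that would be a different model.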
Bonus question: what role does dataType play? I return a case class, not the StructType defined in dataType, and that does not seem to be an issue. Or does it only work because the StructType happens to match my case class?
Maybe this will clear things up.
The UDAF APIs work on DataFrame Columns. Everything you do has to be serialized just like all the other Columns in the DataFrame. As you note, the only supported MapType is immutable, because that is the only kind of map you can put in a Column. With immutable collections, you just create a new collection that contains the old collection plus a value:
```scala
var map = Map[Long, Long]()
map = map + (0L -> 1234L)
map = map + (1L -> 4567L)
```
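A Map buffer has one catch: its keys are unique. If you key by the value, as in the buffer += input.getLong(0) -> input.getDate(1) idea from the question, repeated values collapse silently. If you key by day, rows from the same day collapse instead. The plain-Scala sketch below shows the difference against a sequence of pairs. It assumes no Spark, uses java.time.LocalDate in place of DateType, and the object name is hypothetical.

```scala
import java.time.LocalDate

// Hypothetical demo: accumulate (value, day) rows into a Map vs. a Seq
object MapVsSeqBuffer {
  val rows: Seq[(Long, LocalDate)] = Seq(
    5L -> LocalDate.of(2016, 1, 1),
    5L -> LocalDate.of(2016, 1, 2), // same value, different day
    7L -> LocalDate.of(2016, 1, 3))

  // Keyed by value: the second 5L entry replaces the first
  val asMap: Map[Long, LocalDate] =
    rows.foldLeft(Map.empty[Long, LocalDate])((m, row) => m + row)

  // A sequence of pairs keeps every row
  val asSeq: Seq[(Long, LocalDate)] =
    rows.foldLeft(Vector.empty[(Long, LocalDate)])((s, row) => s :+ row)

  def main(args: Array[String]): Unit = {
    println(asMap.size) // 2
    println(asSeq.size) // 3
  }
}
```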
Yes. Just like when working with any DataFrame, your types have to match. If you call buffer.getInt(0) when the value is really a LongType, you will get an error at runtime.
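Here is a minimal illustration. It assumes spark-sql is on the classpath; a Row can be built without a SparkSession, and the object name is hypothetical. The getter has to match the stored type, and the index follows the column order. So with an inputSchema of units: LongType followed by created_day: DateType, update() would read input.getLong(0) and input.getDate(1).

```scala
import org.apache.spark.sql.Row
import scala.util.Try

// Hypothetical demo: Row getters must match the underlying type
object RowTypeDemo {
  // Same layout as a two-column inputSchema: units (LongType), created_day (DateType)
  val input: Row = Row(42L, java.sql.Date.valueOf("2016-01-01"))

  val units: Long = input.getLong(0)         // matches LongType
  val day: java.sql.Date = input.getDate(1)  // matches DateType
  val wrong: Try[Int] = Try(input.getInt(0)) // LongType read as Int: fails at runtime

  def main(args: Array[String]): Unit = {
    println(units)           // 42
    println(day)             // 2016-01-01
    println(wrong.isFailure) // true
  }
}
```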
There's no standard way to reset the buffer other than whatever makes sense for your data type and use case. Maybe zero is actually last month's balance; maybe zero is a running average from another dataset; maybe zero is a null or an empty string; or maybe zero is really zero.
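On the three forms listed in the question, two of them are the same thing. buffer(0) = x is Scala syntax sugar for buffer.update(0, x), so the first two forms are the same call. buffer = Nil does not compile, because method parameters cannot be reassigned. The sketch below uses FakeBuffer, a hypothetical class that is not Spark's MutableAggregationBuffer but has the same update(i, value) shape, so it runs without Spark.

```scala
// Hypothetical stand-in for MutableAggregationBuffer, only to show Scala's update sugar
class FakeBuffer(size: Int) {
  private val slots = new Array[Any](size)
  def update(i: Int, value: Any): Unit = slots(i) = value
  def apply(i: Int): Any = slots(i)
}

object InitializeDemo {
  def initialize(buffer: FakeBuffer): Unit = {
    buffer.update(0, 0.0) // explicit call
    buffer(1) = 0.0       // compiled as buffer.update(1, 0.0)
    // buffer = Nil       // does not compile: parameters cannot be reassigned
  }

  def main(args: Array[String]): Unit = {
    val b = new FakeBuffer(2)
    initialize(b)
    println(b(0) == b(1)) // true
  }
}
```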
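merge is how Spark combines partial results. It typically aggregates each partition into its own buffer with update and then combines those partial buffers with merge. So merge should not be left empty: an empty merge throws away whatever the other buffers collected. For a list-style buffer like yours, merge does the same job as update, except that it appends a whole list (buffer2's) instead of a single input value.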
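The Spark-free model below shows what goes wrong with an empty merge. It is loosely based on two-phase aggregation: each partition is folded with update, then a freshly initialized buffer has every partial buffer merged into it. All names are hypothetical, and this is a simplification rather than Spark's actual execution code.

```scala
// Hypothetical, Spark-free model of two-phase aggregation
object MergeDemo {
  type Buffer = Vector[Long]

  val zero: Buffer = Vector.empty
  def update(buf: Buffer, value: Long): Buffer = buf :+ value
  def merge(b1: Buffer, b2: Buffer): Buffer = b1 ++ b2
  def emptyMerge(b1: Buffer, b2: Buffer): Buffer = b1 // the effect of def merge(...) = {}

  def aggregate(partitions: Seq[Seq[Long]], mergeFn: (Buffer, Buffer) => Buffer): Buffer =
    partitions
      .map(_.foldLeft(zero)(update)) // partial aggregation per partition
      .foldLeft(zero)(mergeFn)       // final aggregation into a fresh buffer

  def main(args: Array[String]): Unit = {
    val parts = Seq(Seq(3L, 1L), Seq(2L, 4L))
    println(aggregate(parts, merge))      // Vector(3, 1, 2, 4)
    println(aggregate(parts, emptyMerge)) // Vector()
  }
}
```

The merged list is also not in date order, which is another reason to sort in evaluate rather than rely on the order of arrival.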
A case class is automatically converted to the matching struct, so for the bonus question the answer is yes: it works because the schemas match. If you change dataType so that it no longer matches, you will get an error.
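Putting these points together, here is a minimal sketch of the two-argument version asked about in the question. It makes several assumptions:

- Spark 2.x.
- commons-math3 on the classpath, as in the question.
- DateType values arriving as java.sql.Date, which is Spark's default. With spark.sql.datetime.java8API.enabled they would be LocalDate instead.
- UDAFRegressionByDay is a hypothetical class name.

The buffer is an array of (day, units) structs, so duplicate values and days are kept. merge concatenates, and evaluate sorts before fitting. Note that UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of Aggregator; this sketch sticks to the API used in the question.

```scala
import org.apache.commons.math3.stat.regression.SimpleRegression
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical two-argument variant of UDAFRegression: collects (day, units)
// pairs in any order and sorts them by day before fitting the regression.
class UDAFRegressionByDay extends UserDefinedAggregateFunction {
  // Input columns in call order: regression($"longValue", $"created_day")
  def inputSchema: StructType = new StructType()
    .add("units", LongType)
    .add("created_day", DateType)

  // Buffer: one array of (day, units) structs; duplicates are kept
  def bufferSchema: StructType = new StructType()
    .add("pairs", ArrayType(new StructType().add("day", DateType).add("units", LongType)))

  // Result: matches the Row(intercept, slope) returned by evaluate
  def dataType: DataType = new StructType()
    .add("intercept", DoubleType)
    .add("slope", DoubleType)

  def deterministic: Boolean = true

  // Zero for this aggregation: an empty list of pairs
  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[Row]

  // Append one pair (skipping rows with a null value or day) and write a new Seq back
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0) && !input.isNullAt(1))
      buffer(0) = buffer.getSeq[Row](0) :+ Row(input.getDate(1), input.getLong(0))

  // Concatenate partial buffers built on different partitions
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getSeq[Row](0) ++ buffer2.getSeq[Row](0)

  // Sort by day, then regress units on position (0, 1, 2, ...) as in the question
  def evaluate(buffer: Row): Any = {
    val ordered = buffer.getSeq[Row](0).sortBy(_.getDate(0).getTime).map(_.getLong(1))
    val regression = new SimpleRegression()
    ordered.zipWithIndex.foreach { case (y, x) => regression.addData(x.toDouble, y.toDouble) }
    Row(regression.getIntercept(), regression.getSlope())
  }
}
```

Usage would look like df.agg(new UDAFRegressionByDay()($"longValue", $"created_day")), or the same call inside groupBy(...).agg(...). Returning the question's RegressionData case class from evaluate should work just as well, since its fields line up with dataType. The original roundAt step is left out for brevity. Copying the whole list on every update keeps the sketch simple rather than fast.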