Reputation: 6206
For my specific requirement, I want to write a UDAF that simply collects all the input rows.
The input rows have two columns, both of type Double.
I think the intermediate (buffer) schema should be an ArrayList (correct me if I am wrong).
The returned data type should also be an ArrayList.
I wrote a rough sketch of my UDAF, but I would like someone to help me finish it.
/**
 * Collects every input row of a group into an array of `(value, y)` structs.
 *
 * Both the aggregation buffer and the result are `ArrayType(StructType(value: Double, y: Double))`.
 * A JVM collection such as `util.ArrayList` is not a Spark SQL `DataType`, so it cannot appear
 * in a schema; Spark's `ArrayType` is the schema-level equivalent.
 * The order of rows within a group depends on how Spark merges partial buffers.
 */
class CollectorUDAF() extends UserDefinedAggregateFunction {
  // Shape of one input row, reused as the element type of the buffer and of the result.
  private val rowSchema: StructType =
    StructType(Array(StructField("value", DoubleType), StructField("y", DoubleType)))

  // Input Data Type Schema: two Double columns.
  def inputSchema: StructType = rowSchema

  // Intermediate Schema: a single array column holding the rows collected so far.
  def bufferSchema: StructType = StructType(Array(StructField("rows", ArrayType(rowSchema))))

  // Returned Data Type: every collected row of the group.
  def dataType: DataType = ArrayType(rowSchema)

  // Marked deterministic, as in the original sketch.
  def deterministic: Boolean = true

  // Sets a fresh aggregation buffer to its zero value: an empty array.
  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Seq.empty[Row])

  // Appends one input row to the buffer's array.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer.update(0, buffer.getSeq[Row](0) :+ input)

  // Concatenates another partial buffer's rows onto this buffer.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getSeq[Row](0) ++ buffer2.getSeq[Row](0))

  // Called after all the entries are exhausted; returns the collected rows.
  def evaluate(buffer: Row): Any = buffer.getSeq[Row](0)
}
Upvotes: 4
Views: 2954
Reputation: 41957
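If I understood your question correctly, the following should be your solution. Note that it collects both columns into a single flat array of doubles (`value1, y1, value2, y2, ...`) rather than an array of rows: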
/**
 * Collects the two Double columns of every input row of a group into one flat array.
 *
 * Each input row `(value, y)` appends both numbers back to back, so a group yields
 * `[value1, y1, value2, y2, ...]`. The relative order of values coming from different partial
 * buffers depends on how Spark merges them.
 *
 * `Row.getDouble` throws on a null value, so rows with a null in either column are skipped
 * instead, similar to how Spark's built-in `collect_list` ignores nulls.
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `org.apache.spark.sql.expressions.Aggregator` registered through `functions.udaf`.
 */
class CollectorUDAF() extends UserDefinedAggregateFunction {

  // Input Data Type Schema: two Double columns, `value` and `y`.
  def inputSchema: StructType =
    new StructType().add("value", DataTypes.DoubleType).add("y", DataTypes.DoubleType)

  // Intermediate Schema: one array column holding the doubles collected so far.
  // The element type must be DoubleType: Spark converts every value written to the buffer
  // according to this schema, and a StringType element cannot accept Double values.
  val bufferFields: util.ArrayList[StructField] = new util.ArrayList[StructField]
  val bufferStructField: StructField =
    DataTypes.createStructField("array", DataTypes.createArrayType(DataTypes.DoubleType, true), true)
  bufferFields.add(bufferStructField)
  def bufferSchema: StructType = DataTypes.createStructType(bufferFields)

  // Returned Data Type: flat array of doubles, interleaved as value, y, value, y, ...
  def dataType: DataType = DataTypes.createArrayType(DataTypes.DoubleType)

  // Marked deterministic, as in the original answer.
  def deterministic: Boolean = true

  // Sets a fresh aggregation buffer to its zero value: an empty array.
  // (`buffer(0, x)` does not compile -- Row.apply takes one argument -- so `update` is used.)
  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Seq.empty[Double])

  // Appends `value` and then `y` of one input row; rows with a null in either column are skipped.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer.update(0, buffer.getSeq[Double](0) :+ input.getDouble(0) :+ input.getDouble(1))
    }

  // Concatenates the partial array of `buffer2` onto `buffer1`.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0))

  // Called after all the entries are exhausted; returns the collected doubles as an
  // Array[AnyRef] (same result type as before), which Spark converts to `dataType`.
  def evaluate(buffer: Row): Array[AnyRef] =
    buffer.getList[Double](0).toArray()
}
Upvotes: 5