Reputation: 313
I am trying to understand the Java Spark documentation. There is a section called Untyped User Defined Aggregate Functions which has some sample code that I am not able to understand. Here is the code:
package org.apache.spark.examples.sql;

// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$

public class JavaUserDefinedUntypedAggregation {

  // $example on:untyped_custom_aggregation$
  public static class MyAverage extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage() {
      List<StructField> inputFields = new ArrayList<>();
      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
      inputSchema = DataTypes.createStructType(inputFields);

      List<StructField> bufferFields = new ArrayList<>();
      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
      bufferSchema = DataTypes.createStructType(bufferFields);
    }

    // Data types of input arguments of this aggregate function
    public StructType inputSchema() {
      return inputSchema;
    }

    // Data types of values in the aggregation buffer
    public StructType bufferSchema() {
      return bufferSchema;
    }

    // The data type of the returned value
    public DataType dataType() {
      return DataTypes.DoubleType;
    }

    // Whether this function always returns the same output on the identical input
    public boolean deterministic() {
      return true;
    }

    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    public void initialize(MutableAggregationBuffer buffer) {
      buffer.update(0, 0L);
      buffer.update(1, 0L);
    }

    // Updates the given aggregation buffer `buffer` with new input data from `input`
    public void update(MutableAggregationBuffer buffer, Row input) {
      if (!input.isNullAt(0)) {
        long updatedSum = buffer.getLong(0) + input.getLong(0);
        long updatedCount = buffer.getLong(1) + 1;
        buffer.update(0, updatedSum);
        buffer.update(1, updatedCount);
      }
    }

    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
      buffer1.update(0, mergedSum);
      buffer1.update(1, mergedCount);
    }

    // Calculates the final result
    public Double evaluate(Row buffer) {
      return ((double) buffer.getLong(0)) / buffer.getLong(1);
    }
  }
  // $example off:untyped_custom_aggregation$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("Java Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate();

    // $example on:untyped_custom_aggregation$
    // Register the function to access it
    spark.udf().register("myAverage", new MyAverage());

    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
    df.createOrReplaceTempView("employees");
    df.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:untyped_custom_aggregation$

    spark.stop();
  }
}
My doubts related to the above code are:

1. Whenever I want to create a UDF, should I have the functions initialize, update and merge?
2. What is the significance of the variables inputSchema and bufferSchema? I'm surprised they exist, because they are never used to create any DataFrames at all. Are they supposed to be present in every UDF? If yes, are they supposed to have these exact names?
3. Why are inputSchema and bufferSchema not named getInputSchema() and getBufferSchema()? Why are there no setters for these variables?
4. What is the significance of the function called deterministic() here? Please give a scenario when it would be useful to call this function.

In general, I want to know how to write a user-defined aggregate function in Spark.
Upvotes: 1
Views: 4853
Reputation: 74779
Whenever I want to create a UDF, should I have the functions initialize, update and merge?
UDF stands for user-defined function, while the methods initialize, update, and merge are for user-defined aggregate functions (aka UDAF).
A UDF is a function that works with a single row to (usually) produce one row (e.g. the upper function).
A UDAF is a function that works with zero or more rows to produce one row (e.g. the count aggregate function).
You certainly don't have to (and won't be able to) have the functions initialize, update and merge for user-defined functions (UDFs).
Use any of the udf functions to define and register a UDF.
val myUpper = udf { (s: String) => s.toUpperCase }
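Since your code is Java, here is a minimal sketch of the same idea with the Java API (the class name and the function name myUpper are mine, not from the docs):

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

public class MyUpperExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("udf-demo").getOrCreate();
    // A plain UDF: one value in, one value out. There is no
    // initialize/update/merge -- those belong to the UDAF contract only.
    spark.udf().register(
        "myUpper",
        (UDF1<String, String>) s -> s == null ? null : s.toUpperCase(),
        DataTypes.StringType);
    spark.sql("SELECT myUpper('hello') AS upper_hello").show();
    spark.stop();
  }
}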
How to write a user-defined aggregate function in Spark.
What is the significance of the variables inputSchema and bufferSchema?
(Shameless plug: I've been describing UDAFs in Mastering Spark SQL book in UserDefinedAggregateFunction — Contract for User-Defined Untyped Aggregate Functions (UDAFs))
Quoting Untyped User-Defined Aggregate Functions:
// Data types of input arguments of this aggregate function
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

// Data types of values in the aggregation buffer
def bufferSchema: StructType = {
  StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}
In other words, inputSchema is what you expect from the input, while bufferSchema is what you keep temporarily while doing the aggregation.
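To make that concrete, inputSchema is what lets Spark bind the arguments you pass at the call site. A sketch of invoking the registered MyAverage through the DataFrame API rather than SQL (it reuses the employees DataFrame df and the "myAverage" registration from your code):

import static org.apache.spark.sql.functions.callUDF;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// "salary" lines up with the single LongType column declared in inputSchema;
// callUDF looks up the function registered under the name "myAverage".
Dataset<Row> result = df.select(callUDF("myAverage", df.col("salary")).as("average_salary"));
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+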
Why are there no setters of these variables?
They are extension points that are managed by Spark: Spark only ever reads them (to learn the input and buffer schemas), so the contract defines no setters. The get-less names come from the parent class UserDefinedAggregateFunction being written in Scala, where parameterless methods act as accessors without the JavaBeans get prefix; the Java example simply overrides those methods.
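You can see the read-only nature outside Spark entirely; a quick sketch (the printed forms are approximate):

MyAverage myAverage = new MyAverage();
// Spark calls these accessors itself during planning; user code never assigns
// them from the outside, it only overrides what they return.
System.out.println(myAverage.inputSchema());   // struct with one field: inputColumn (long)
System.out.println(myAverage.bufferSchema());  // struct with two fields: sum, count (long)
System.out.println(myAverage.dataType());      // DoubleType
System.out.println(myAverage.deterministic()); // true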
What is the significance of the function called deterministic() here?
Quoting Untyped User-Defined Aggregate Functions:
// Whether this function always returns the same output on the identical input
def deterministic: Boolean = true
Please give a scenario when it would be useful to call this function.
That's something I'm still working on and so can't answer today.
Upvotes: 6