Reputation: 7605
I am trying to reproduce the User Defined Aggregate Functions example provided in the Spark SQL Guide.
The only change I am making with respect to the original code is the DataFrame creation:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions // used further down for functions.udaf
import spark.implicits._              // for .toDF on a local Seq (already in scope in spark-shell)

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Specifies the Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // Specifies the Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val originalDF = Seq(
  ("Michael", 3000),
  ("Andy", 4500),
  ("Justin", 3500),
  ("Berta", 4000)
).toDF("name", "salary")
+-------+------+
|name |salary|
+-------+------+
|Michael|3000 |
|Andy |4500 |
|Justin |3500 |
|Berta |4000 |
+-------+------+
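For comparison, the guide's original version of this step reads the sample employees.json file that ships with the Spark examples, roughly as below (the path assumes a Spark distribution checkout):
// Original Dataset creation from the Spark SQL Guide, shown here only for comparison;
// the file path assumes the examples directory of a Spark source or binary distribution.
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()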
When I use this UDAF with Spark SQL (the second option in the documentation):
spark.udf.register("myAverage", functions.udaf(MyAverage))
originalDF.createOrReplaceTempView("employees")
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
Everything goes as expected:
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
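For completeness, since myAverage is registered in the session, the same aggregation can presumably also be run from the DataFrame API with selectExpr (not something I need here, just the same registered function invoked without a temp view):
// Same registered UDAF, invoked through selectExpr instead of a temporary view.
val resultExpr = originalDF.selectExpr("myAverage(salary) AS average_salary")
resultExpr.show()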
However, when I try the approach that converts the function to a TypedColumn:
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = originalDF.as[Employee].select(averageSalary)
result.show()
I get the following exception:
Job aborted due to stage failure.
Caused by: NotSerializableException: org.apache.spark.sql.TypedColumn
Serialization stack:
- object not serializable (class: org.apache.spark.sql.TypedColumn, value: myaverage(knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum AS sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count AS count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), boundreference()) AS average_salary)
- field (class: $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw, name: averageSalary, type: class org.apache.spark.sql.TypedColumn)
- object (class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw@1254d4c6)
- field (class: $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$, name: $outer, type: class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw)
- object (class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1)
- field (class: org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression, name: aggregator, type: class org.apache.spark.sql.expressions.Aggregator)
- object (class org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression, MyAverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee))
- field (class: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression, name: aggregateFunction, type: class org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction)
- object (class org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression, partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308)
- writeObject data (class: scala.collection.immutable.List$SerializationProxy)
- object (class scala.collection.immutable.List$SerializationProxy, scala.collection.immutable.List$SerializationProxy@f939d16)
- writeReplace data (class: scala.collection.immutable.List$SerializationProxy)
- object (class scala.collection.immutable.$colon$colon, List(partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308))
- field (class: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, name: aggregateExpressions, type: interface scala.collection.Seq)
- object (class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, ObjectHashAggregate(keys=[], functions=[partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308], output=[buf#308])
+- LocalTableScan [name#274, salary#275]
)
- element of array (index: 0)
- array (class [Ljava.lang.Object;, size 6)
- field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
- object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, functionalInterfaceMethod=scala/Function2.apply:(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.$anonfun$doExecute$1$adapted:(Lorg/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec;ILorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, instantiatedMethodType=(Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, numCaptured=6])
- writeReplace data (class: java.lang.invoke.SerializedLambda)
- object (class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$Lambda$5930/237479585, org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$Lambda$5930/237479585@6de6a3e)
- element of array (index: 0)
- array (class [Ljava.lang.Object;, size 1)
- field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
- object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.rdd.RDD, functionalInterfaceMethod=scala/Function3.apply:(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/rdd/RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted:(Lscala/Function2;Lorg/apache/spark/TaskContext;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, instantiatedMethodType=(Lorg/apache/spark/TaskContext;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, numCaptured=1])
- writeReplace data (class: java.lang.invoke.SerializedLambda)
- object (class org.apache.spark.rdd.RDD$$Lambda$5932/1340469986, org.apache.spark.rdd.RDD$$Lambda$5932/1340469986@7939a132)
- field (class: org.apache.spark.rdd.MapPartitionsRDD, name: f, type: interface scala.Function3)
- object (class org.apache.spark.rdd.MapPartitionsRDD, MapPartitionsRDD[20] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80)
- field (class: org.apache.spark.NarrowDependency, name: _rdd, type: class org.apache.spark.rdd.RDD)
- object (class org.apache.spark.OneToOneDependency, org.apache.spark.OneToOneDependency@1e0b1350)
- writeObject data (class: scala.collection.immutable.List$SerializationProxy)
- object (class scala.collection.immutable.List$SerializationProxy, scala.collection.immutable.List$SerializationProxy@29edc56a)
- writeReplace data (class: scala.collection.immutable.List$SerializationProxy)
- object (class scala.collection.immutable.$colon$colon, List(org.apache.spark.OneToOneDependency@1e0b1350))
- field (class: org.apache.spark.rdd.RDD, name: dependencies_, type: interface scala.collection.Seq)
- object (class org.apache.spark.rdd.MapPartitionsRDD, MapPartitionsRDD[21] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80)
- field (class: scala.Tuple2, name: _1, type: class java.lang.Object)
- object (class scala.Tuple2, (MapPartitionsRDD[21] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80,org.apache.spark.ShuffleDependency@567dc75c))
What am I missing?
I am running this script on DBR 11.0 (Databricks Runtime), with Spark 3.3.0 and Scala 2.12.
Upvotes: 0
Views: 266
Reputation: 7605
Applying the toColumn inside the select() fixed the problem:
val result = originalDF.as[Employee].select(MyAverage.toColumn.name("average_salary"))
result.show()
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+
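The serialization stack in the question hints at why: in a notebook or REPL, a top-level val such as averageSalary becomes a field of the generated wrapper object (the $iw classes in the trace), and MyAverage keeps an $outer reference to that wrapper. When Spark ships the aggregator to the executors, it therefore also tries to serialize the TypedColumn stored in that field, which is not serializable. Inlining the toColumn call means the column is never stored on the wrapper. Under that same assumption, another workaround should be to keep the column in a local scope, for example inside a helper method (a sketch, not tested on DBR):
import org.apache.spark.sql.Dataset

// Hypothetical helper: the TypedColumn lives only in a local val, so it never becomes
// a field of the notebook's wrapper object and is not dragged into the task closure.
def averageSalaryOf(ds: Dataset[Employee]): Dataset[Double] = {
  val avgCol = MyAverage.toColumn.name("average_salary")
  ds.select(avgCol)
}

averageSalaryOf(originalDF.as[Employee]).show()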
Upvotes: 1