Ignacio Alorre

Reputation: 7605

NotSerializableException: org.apache.spark.sql.TypedColumn when calling a UDAF

I am trying to reproduce the User-Defined Aggregate Functions example from the Spark SQL Guide.

The only change with respect to the original code is the DataFrame creation:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions  // needed for functions.udaf below

case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Specifies the Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // Specifies the Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

val originalDF = Seq(
  ("Michael", 3000),
  ("Andy", 4500),
  ("Justin", 3500),
  ("Berta", 4000)
).toDF("name", "salary")

+-------+------+
|name   |salary|
+-------+------+
|Michael|3000  |
|Andy   |4500  |
|Justin |3500  |
|Berta  |4000  |
+-------+------+
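
For reference, outside a Databricks notebook the snippet above also needs a SparkSession and its implicits in scope before toDF and as[Employee] will compile. A minimal sketch (the app name and local master are just placeholders for testing):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("udaf-typedcolumn-example")  // placeholder name
  .master("local[*]")                   // assumption: local test run
  .getOrCreate()

import spark.implicits._  // brings in the toDF / as[Employee] conversions and encoders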

When I try to use this UDAF with Spark SQL (the second option in the documentation):

spark.udf.register("myAverage", functions.udaf(MyAverage))

originalDF.createOrReplaceTempView("employees")

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")

result.show()

Everything goes as expected:

+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+

However, when I try the approach that converts the function to a TypedColumn:

val averageSalary = MyAverage.toColumn.name("average_salary")
val result = originalDF.as[Employee].select(averageSalary)
result.show()

I am getting the following exception:

Job aborted due to stage failure.
Caused by: NotSerializableException: org.apache.spark.sql.TypedColumn
Serialization stack:
    - object not serializable (class: org.apache.spark.sql.TypedColumn, value: myaverage(knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum AS sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count AS count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), boundreference()) AS average_salary)
    - field (class: $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw, name: averageSalary, type: class org.apache.spark.sql.TypedColumn)
    - object (class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw@1254d4c6)
    - field (class: $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$, name: $outer, type: class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw)
    - object (class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1)
    - field (class: org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression, name: aggregator, type: class org.apache.spark.sql.expressions.Aggregator)
    - object (class org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression, MyAverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee))
    - field (class: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression, name: aggregateFunction, type: class org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction)
    - object (class org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression, partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308)
    - writeObject data (class: scala.collection.immutable.List$SerializationProxy)
    - object (class scala.collection.immutable.List$SerializationProxy, scala.collection.immutable.List$SerializationProxy@f939d16)
    - writeReplace data (class: scala.collection.immutable.List$SerializationProxy)
    - object (class scala.collection.immutable.$colon$colon, List(partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308))
    - field (class: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, name: aggregateExpressions, type: interface scala.collection.Seq)
    - object (class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, ObjectHashAggregate(keys=[], functions=[partial_myaverage($line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$MyAverage$@60a7eee1, Some(newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee)), Some(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Employee), Some(StructType(StructField(name,StringType,true),StructField(salary,LongType,false))), knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).sum, knownnotnull(assertnotnull(input[0, $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average, true])).count, newInstance(class $line24f4d7b3f7f54dfc89ae8e2757da4abf39.$read$$iw$$iw$$iw$$iw$$iw$$iw$Average), input[0, double, false], DoubleType, false, 0, 0) AS buf#308], output=[buf#308])
+- LocalTableScan [name#274, salary#275]
)
    - element of array (index: 0)
    - array (class [Ljava.lang.Object;, size 6)
    - field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
    - object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec, functionalInterfaceMethod=scala/Function2.apply:(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.$anonfun$doExecute$1$adapted:(Lorg/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec;ILorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Lorg/apache/spark/sql/execution/metric/SQLMetric;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, instantiatedMethodType=(Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, numCaptured=6])
    - writeReplace data (class: java.lang.invoke.SerializedLambda)
    - object (class org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$Lambda$5930/237479585, org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$Lambda$5930/237479585@6de6a3e)
    - element of array (index: 0)
    - array (class [Ljava.lang.Object;, size 1)
    - field (class: java.lang.invoke.SerializedLambda, name: capturedArgs, type: class [Ljava.lang.Object;)
    - object (class java.lang.invoke.SerializedLambda, SerializedLambda[capturingClass=class org.apache.spark.rdd.RDD, functionalInterfaceMethod=scala/Function3.apply:(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;, implementation=invokeStatic org/apache/spark/rdd/RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted:(Lscala/Function2;Lorg/apache/spark/TaskContext;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, instantiatedMethodType=(Lorg/apache/spark/TaskContext;Ljava/lang/Object;Lscala/collection/Iterator;)Lscala/collection/Iterator;, numCaptured=1])
    - writeReplace data (class: java.lang.invoke.SerializedLambda)
    - object (class org.apache.spark.rdd.RDD$$Lambda$5932/1340469986, org.apache.spark.rdd.RDD$$Lambda$5932/1340469986@7939a132)
    - field (class: org.apache.spark.rdd.MapPartitionsRDD, name: f, type: interface scala.Function3)
    - object (class org.apache.spark.rdd.MapPartitionsRDD, MapPartitionsRDD[20] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80)
    - field (class: org.apache.spark.NarrowDependency, name: _rdd, type: class org.apache.spark.rdd.RDD)
    - object (class org.apache.spark.OneToOneDependency, org.apache.spark.OneToOneDependency@1e0b1350)
    - writeObject data (class: scala.collection.immutable.List$SerializationProxy)
    - object (class scala.collection.immutable.List$SerializationProxy, scala.collection.immutable.List$SerializationProxy@29edc56a)
    - writeReplace data (class: scala.collection.immutable.List$SerializationProxy)
    - object (class scala.collection.immutable.$colon$colon, List(org.apache.spark.OneToOneDependency@1e0b1350))
    - field (class: org.apache.spark.rdd.RDD, name: dependencies_, type: interface scala.collection.Seq)
    - object (class org.apache.spark.rdd.MapPartitionsRDD, MapPartitionsRDD[21] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80)
    - field (class: scala.Tuple2, name: _1, type: class java.lang.Object)
    - object (class scala.Tuple2, (MapPartitionsRDD[21] at $anonfun$executeCollectResult$1 at FrameProfiler.scala:80,org.apache.spark.ShuffleDependency@567dc75c))

What am I missing?

I am running this script on DBR 11.0, with Spark 3.3.0 and Scala 2.12.

Upvotes: 0

Views: 266

Answers (1)

Ignacio Alorre

Reputation: 7605

Applying toColumn directly inside select() fixed the problem. Judging from the serialization stack, the intermediate averageSalary val became a field of the notebook wrapper object that MyAverage references through its $outer, which dragged the non-serializable TypedColumn into the task closure:

val result = originalDF.as[Employee].select(MyAverage.toColumn.name("average_salary"))
result.show()

+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
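
An untested variation that should also avoid the capture is to keep the intermediate value but make it local instead of a top-level notebook val (the helper name is just for illustration):

import org.apache.spark.sql.Dataset

// Hypothetical helper: the TypedColumn lives only in the method body,
// so it never becomes a field of the notebook wrapper object.
def averageSalaryOf(employees: Dataset[Employee]): Dataset[Double] = {
  val averageSalary = MyAverage.toColumn.name("average_salary")
  employees.select(averageSalary)
}

averageSalaryOf(originalDF.as[Employee]).show()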

Upvotes: 1
