Reputation: 41
Spark 3.0 has deprecated UserDefinedAggregateFunction, and I am trying to rewrite my UDAF using Aggregator. Basic usage of Aggregator is simple; however, I am struggling with a more generic version of the function.
I will try to explain my problem with this example, an implementation of collect_set. It's not my actual case, but it makes the problem easier to explain:
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero = Set.empty
  // read the integer column `name` from each input row and add it to the buffer
  override def reduce(b: Set[Int], a: Row) = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]) = b1 ++ b2
  override def finish(reduction: Set[Int]) = reduction
  override def bufferEncoder = Encoders.kryo[Set[Int]]
  override def outputEncoder = ExpressionEncoder()
}
// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()
I prefer .toColumn over .udf.register, but that's not the point here.
Problem: I cannot make a universal version of this Aggregator; it only works with integers.
I've attempted:
class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]]
It crashes with the error:
No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)
I could not go with CollectSetDemo[T], because I was not able to define a proper outputEncoder for it. Also, when using a UDAF, I can only work with Spark data types, columns, etc.
Upvotes: 3
Views: 959
Reputation: 3513
Apache DataFu-Spark has an example of this in its MultiArraySet UDAF. (disclosure: I am a member of DataFu and wrote this code).
The declaration looks like this:
/**
* Essentially the same as MultiSet, but gets an Array for input.
* There is an extra option to limit the number of keys (like @CountDistinctUpTo)
*/
class MultiArraySet[IN: Ordering : TypeTag](maxKeys: Int = -1)(implicit t: ClassTag[IN]) extends Aggregator[Array[IN], Map[IN, Int], Map[IN, Int]] with Serializable {
Unlike @Greg's answer, I used kryo encoding for the generic IN type.
implicit val inEncoder: Encoder[IN] = Encoders.kryo[IN]
def bufferEncoder: Encoder[Map[IN, Int]] = implicitly[Encoder[Map[IN, Int]]]
def outputEncoder: Encoder[Map[IN, Int]] = implicitly[Encoder[Map[IN, Int]]]
Upvotes: 0
Reputation: 589
A modification of @Ramunas's answer, using generics:
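import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.{Encoders, Row}
// mirror, serializerForType and deserializerForType are Spark-internal (catalyst) helpers
import org.apache.spark.sql.catalyst.ScalaReflection.{deserializerForType, mirror, serializerForType}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
  override def zero = Set.empty
  override def reduce(b: Set[T], a: Row) = b + a.getAs[T](a.fieldIndex(name))
  override def merge(b1: Set[T], b2: Set[T]) = b1 ++ b2
  override def finish(reduction: Set[T]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[T]]
  // build the output encoder for Seq[T] from the TypeTag instead of hard-coding the element type
  override def outputEncoder = {
    val tt = typeTag[Seq[T]]
    val tpe = tt.in(mirror).tpe
    val cls = mirror.runtimeClass(tpe)
    val serializer = serializerForType(tpe)
    val deserializer = deserializerForType(tpe)
    new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
  }
}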
Upvotes: 1
Reputation: 41
I have not found a nice way to solve this, but I was able to work around it to some extent. The code is partially borrowed from RowEncoder:
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.typeTag
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._

class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
  override def zero = Set.empty
  override def reduce(b: Set[Any], a: Row) = b + a.get(a.fieldIndex(name))
  override def merge(b1: Set[Any], b2: Set[Any]) = b1 ++ b2
  override def finish(reduction: Set[Any]) = reduction.toSeq
  override def bufferEncoder = Encoders.kryo[Set[Any]]
  // the output encoder is now chosen at runtime from the requested result DataType
  override def outputEncoder = {
    val mirror = ScalaReflection.mirror
    val tt = fieldType match {
      case ArrayType(LongType, _) => typeTag[Seq[Long]]
      case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
      case ArrayType(StringType, _) => typeTag[Seq[String]]
      // .. etc etc
      case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
    }
    val tpe = tt.in(mirror).tpe
    val cls = mirror.runtimeClass(tpe)
    val serializer = ScalaReflection.serializerForType(tpe)
    val deserializer = ScalaReflection.deserializerForType(tpe)
    new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
  }
}
One thing I had to add was a parameter on the aggregator for the result data type. The usage then changed to:
df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()
I really don't like how it turned out, but it works. I also welcome any suggestions on how to improve it.
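One possible direction, as a rough sketch rather than a tested solution: type the aggregator on the column value instead of Row, and let the caller supply the output encoder implicitly, so no reflection is needed inside the class. This assumes Spark 3.0+, where functions.udaf can wrap an Aggregator for use on untyped columns. The names TypedCollectSet and TypedCollectSetDemo, the local session, and the sample data are made up for illustration.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Hypothetical aggregator: generic in the element type T.
// The output encoder is passed in by the caller instead of being built by reflection.
class TypedCollectSet[T](implicit outEnc: Encoder[Seq[T]])
    extends Aggregator[T, Set[T], Seq[T]] {
  override def zero: Set[T] = Set.empty[T]
  override def reduce(b: Set[T], a: T): Set[T] = b + a
  override def merge(b1: Set[T], b2: Set[T]): Set[T] = b1 ++ b2
  override def finish(reduction: Set[T]): Seq[T] = reduction.toSeq
  // the intermediate buffer is opaque to Spark, so kryo is enough here
  override def bufferEncoder: Encoder[Set[T]] = Encoders.kryo[Set[T]]
  override def outputEncoder: Encoder[Seq[T]] = outEnc
}

object TypedCollectSetDemo {
  def main(args: Array[String]): Unit = {
    // assumption: a throwaway local session and a tiny sample DataFrame
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._ // provides the implicit Encoder[Seq[Int]]

    val df = Seq(1, 2, 2, 3).toDF("rank")
    // udaf derives the input encoder for Int from its TypeTag
    val collectInts = udaf(new TypedCollectSet[Int])
    df.agg(collectInts(col("rank")) as "result").show()

    spark.stop()
  }
}
```

The trade-off is that the element type still has to be known at the call site (TypedCollectSet[Int] here), so this does not remove the need to choose a type per column; it only moves that choice out of the aggregator and into ordinary implicit encoder resolution.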
Upvotes: 1