Hang Wu

Reputation: 15

Determine the function signature of an anonymous function in Scala

The context is registering a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether this is possible.

Say the function impl() returns an anonymous function:

import org.reflections.Reflections
import scala.jdk.CollectionConverters._

trait Base {}
class A extends Base {
  def impl(): Function1[Int, String] = new Function1[Int, String] {
    def apply(x: Int): String = "ab" + x.toString
  }
}
val reflections = new Reflections()
val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

and I obtain the anonymous function in another place:

val clazz = classes(0)
val instance = clazz.newInstance()
val impl = clazz.getDeclaredMethod("impl").invoke(instance)

Now impl holds the anonymous function, but I do not know its signature, and I'd like to ask whether we can cast it to the correct function type:

impl.asInstanceOf[Function1[Int, String]]   // How to determine the function signature of the anonymous function, in this case Function1[Int, String]?

Since Scala does not support generic functions, I first considered getting the runtime type of the function:

import scala.reflect.runtime.universe.{TypeTag, typeTag}
def getTypeTag[T: TypeTag](obj: T) = typeTag[T]
val typeList = getTypeTag(impl).tpe.typeArgs

It will return List(Int, String), but I fail to reconstruct the correct function type from it via reflection.

Update: if the classes are defined as follows:

trait Base {}
class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}

where impl is the method itself and we do not know its signature, can impl still be registered?

Upvotes: 1

Views: 387

Answers (1)

Dmytro Mitin

Reputation: 51703

The context is registering a UserDefinedFunction (UDF) in Spark, where the UDF is an anonymous function obtained via reflection. Since the function's signature is only determined at runtime, I was wondering whether this is possible.

Normally you register a UDF as follows

import org.apache.spark.sql.SparkSession

object App {
  val spark = SparkSession.builder
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  def impl(): Int => String = x => "ab" + x.toString

  spark.udf.register("foo", impl())

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
    //+-------+
    //|foo(10)|
    //+-------+
    //|   ab10|
    //+-------+
  }
}

The signature of register is

def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction

aka

def register[RT, A1](name: String, func: Function1[A1, RT])(implicit
  ttag:  TypeTag[RT],
  ttag1: TypeTag[A1]
): UserDefinedFunction

What TypeTag normally does is persist type information from compile time to runtime.

So in order to call register you either have to know the types at compile time or have to know how to construct the type tags at runtime.
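For example (a minimal sketch; the helper describe is mine, only for illustration):

import scala.reflect.runtime.universe._

// the compiler materializes TypeTag[T] at every call site,
// so the complete static type (with its type arguments) survives into runtime
def describe[T: TypeTag](x: T): String = typeOf[T].toString

describe((x: Int) => "ab" + x.toString) // "Int => String"
describe(List(1, 2, 3))                 // "List[Int]"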

If you don't have access to how impl() is constructed at runtime and you don't have (at least at runtime) any information about the types/type tags, then unfortunately this type information is irreversibly lost because of type erasure (Function1[Int, String] is just Function1[_, _] at runtime):

def impl(): Any = (x: Int) => "ab" + x.toString
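
To see the erasure concretely (just a sketch; the value name erased is mine):

// at runtime only the raw Function1 interface is visible on the object
val erased: Any = impl()
erased.isInstanceOf[Function1[_, _]] // true
// but nothing on the object says it was Function1[Int, String]; a wrong cast like
// erased.asInstanceOf[Int => Int] succeeds here and only fails later, when the function is applied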

But it's possible that you do have access to how impl() is constructed at runtime and you do know (at least at runtime) the information about the types/type tags. So I assume that you don't have the types Int, String statically and can't call typeTag[Int], typeTag[String] (as I do below), but that you somehow have runtime Type/TypeTag objects:

import org.apache.spark.sql.catalyst.ScalaReflection.universe._

def impl(): Any = (x: Int) => "ab" + x.toString
val ttag1 = typeTag[Int]    // actual definition is probably different
val ttag  = typeTag[String] // actual definition is probably different

In that case you can call register, resolving the implicits explicitly:

spark.udf.register("foo", impl().asInstanceOf[Function1[_,_]])(ttag.asInstanceOf[TypeTag[_]], ttag1.asInstanceOf[TypeTag[_]])

Well, this doesn't compile because of the existential types, but you can trick the compiler:

type A
type B
spark.udf.register("foo", impl().asInstanceOf[A => B])(ttag.asInstanceOf[TypeTag[B]], ttag1.asInstanceOf[TypeTag[A]])

https://gist.github.com/DmytroMitin/0b3660d646f74fb109665bad41b3ae9f

Alternatively you can use runtime compilation (creating a new compile time inside the runtime)

import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import scala.tools.reflect.ToolBox // libraryDependencies += scalaOrganization.value % "scala-compiler" % scalaVersion.value

val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl().asInstanceOf[$ttag1 => $ttag])""")

https://gist.github.com/DmytroMitin/5b5dd4d7db0d0eebb51dd8c16735e0fb

You should provide some code showing how you construct impl(), and we'll see whether it's possible to restore the types.

See also:
Spark registered a Scala object all of the methods as a UDF
scala cast object based on reflection symbol


Update. After you've got val impl = clazz.getDeclaredMethod("impl").invoke(instance), it's too late to restore the function types (you can check that typeList is empty). The function type (or type tag) has to be captured somewhere not too far from class A, maybe inside A or outside A, but while Int and String are not lost yet. What TypeTag can do is persist type information from compile time to runtime; it can't restore type information at runtime once it's lost.
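
A quick way to check (just a sketch, reusing your getTypeTag helper):

// impl has static type Object/AnyRef at this point, so the materialized tag
// carries no type arguments any more
val typeList = getTypeTag(impl).tpe.typeArgs // List(), Int and String are already gone

So the type (or type tag) has to be captured while it is still known, for example inside A: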

import org.apache.spark.sql.catalyst.ScalaReflection
import ScalaReflection.universe._
import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._
import scala.reflect.api

object App {
  def getType[T: TypeTag](obj: T) = typeOf[T]

  trait Base
  class A extends Base {
    def impl(): Int => String = x => "ab" + x.toString

    // the anonymous Function1 subclass version causes NotSerializableException:
    //def impl(): Function1[Int, String] = new Function1[Int, String] {
    //  def apply(x: Int): String = "ab" + x.toString
    //}

    val tpe = getType(impl())
  }

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

  val clazz = classes(0)
  val instance = clazz.newInstance()
  val impl = clazz.getDeclaredMethod("impl").invoke(instance)
  val functionType = clazz.getDeclaredMethod("tpe").invoke(instance).asInstanceOf[Type]
  val List(argType, returnType) = functionType.typeArgs

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val rm = ScalaReflection.mirror

  // (*)
  def typeToTypeTag[T](tpe: Type): TypeTag[T] =
    TypeTag(rm, new api.TypeCreator {
      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
        tpe.asInstanceOf[U#Type]
    })

//  type X
//  type Y
//  spark.udf.register("foo", impl.asInstanceOf[X => Y])(
//    typeToTypeTag[Y](returnType),
//    typeToTypeTag[X](argType)
//  )

  impl match {
    case impl: Function1[x, y] => spark.udf.register("foo", impl)(
      typeToTypeTag[y](returnType),
      typeToTypeTag[x](argType)
    )
  }

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }

}

https://gist.github.com/DmytroMitin/2ebfae922f8a467d01b6ef18c8b8e5ad

(*) Get a TypeTag from a Type?

Now spark.sql("""SELECT foo(10)""").show() throws java.io.NotSerializableException, but I guess it's not related to reflection.

Alternatively you can use runtime compilation (instead of manual resolution of implicits and construction of type tags from types)

import scala.tools.reflect.ToolBox

val rm = ScalaReflection.mirror
val tb = rm.mkToolBox()
tb.eval(q"""App.spark.udf.register("foo", App.impl.asInstanceOf[$functionType])""")

https://gist.github.com/DmytroMitin/ba469faeca2230890845e1532b36e2a1

One more option is to request the return type of the method impl() as soon as we get class A (outside A):

class A extends Base {
  def impl(): Int => String = x => "ab" + x.toString
}

// ...
val functionType = rm.classSymbol(clazz).typeSignature.decl(TermName("impl")).asMethod.returnType
val List(argType, returnType) = functionType.typeArgs

https://gist.github.com/DmytroMitin/3bd2c19d158f8241a80952c397ee5e09


Update 2. If the methods are defined as follows:

class A extends Base {
  def impl(x: Int): String = {
    "ab" + x.toString
  }
}

then runtime compilation normally should be

val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val tb = rm.mkToolBox()

tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).$methodSymbol(_))""")

or

tb.eval(q"""App.spark.udf.register("foo", (new $classSymbol).impl(_))""")

but now with Spark it produces ClassCastException: cannot assign instance of java.lang.invoke.SerializedLambda to field org.apache.spark.sql.catalyst.expressions.ScalaUDF.f of type scala.Function1 in instance of org.apache.spark.sql.catalyst.expressions.ScalaUDF, similar to the question Spark registered a Scala object all of the methods as a UDF linked above

https://gist.github.com/DmytroMitin/b0f110f4cf15e2dfd4add70f7124a7b6

But ordinary Scala runtime reflection seems to work

val rm = ScalaReflection.mirror
val classSymbol = rm.classSymbol(clazz)
val methodSymbol = classSymbol.typeSignature.decl(TermName("impl")).asMethod
val returnType = methodSymbol.returnType
val argType = methodSymbol.paramLists.head.head.typeSignature

val constructorSymbol = classSymbol.typeSignature.decl(termNames.CONSTRUCTOR).asMethod
val instance = rm.reflectClass(classSymbol).reflectConstructor(constructorSymbol)()
val impl: Any => Any = rm.reflect(instance).reflectMethod(methodSymbol)(_)

def typeToTypeTag[T](tpe: Type): TypeTag[T] =
  TypeTag(rm, new api.TypeCreator {
    def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
      tpe.asInstanceOf[U#Type]
  })

impl match {
  case impl: Function1[x, y] => spark.udf.register("foo", impl)(
    typeToTypeTag[y](returnType),
    typeToTypeTag[x](argType)
  )
}

https://gist.github.com/DmytroMitin/763751096fe9cdb2e0d18ae4b9290a54


Update 3. One more approach is to use compile-time reflection (macros) rather than runtime reflection if you have enough information at compile time (e.g. if all the classes are known at compile time)

import scala.collection.mutable
import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod[A](): Unit = macro registerMethodImpl[A]

  def registerMethodImpl[A: c.WeakTypeTag](c: blackbox.Context)(): c.Tree = {
    import c.universe._
    val A = weakTypeOf[A]

    var children = mutable.Seq[Type]()

    val traverser = new Traverser {
      override def traverse(tree: Tree): Unit = {
        tree match {
          case _: ClassDef =>
            val tpe = tree.symbol.asClass.toType
            if (tpe <:< A && !(tpe =:= A)) children :+= tpe
          case _ =>
        }

        super.traverse(tree)
      }
    }

    c.enclosingRun.units.foreach(unit => traverser.traverse(unit.body))

    val calls = children.map(tpe =>
      q"""spark.udf.register("foo", (new $tpe).impl(_))"""
    )

    q"..$calls"
  }
}
// in a different subproject

import org.apache.spark.sql.SparkSession

object App {
  trait Base

  class A extends Base {
    def impl(x: Int): String = "ab" + x.toString
  }

  val spark = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  Macros.registerMethod[Base]()

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}

https://gist.github.com/DmytroMitin/6623f1f900330f8341f209e1347a0007

See also: Shapeless - How to derive LabelledGeneric for Coproduct (KnownSubclasses)


Update 4. If we replace val clazz = classes(0) with classes.foreach(clazz => ...), then the issues with NotSerializableException can be fixed with inlining:

import scala.language.experimental.macros
import scala.reflect.macros.blackbox

object Macros {
  def registerMethod(clazz: Class[_]): Unit = macro registerMethodImpl

  def registerMethodImpl(c: blackbox.Context)(clazz: c.Tree): c.Tree = {
    import c.universe._

    val ScalaReflection = q"_root_.org.apache.spark.sql.catalyst.ScalaReflection"
    val rm = q"$ScalaReflection.mirror"
    val ru = q"$ScalaReflection.universe"
    val classSymbol = q"$rm.classSymbol($clazz)"
    val methodSymbol = q"""$classSymbol.typeSignature.decl($ru.TermName("impl")).asMethod"""
    val returnType = q"$methodSymbol.returnType"
    val argType = q"$methodSymbol.paramLists.head.head.typeSignature"

    val constructorSymbol = q"$classSymbol.typeSignature.decl($ru.termNames.CONSTRUCTOR).asMethod"
    val instance = q"$rm.reflectClass($classSymbol).reflectConstructor($constructorSymbol).apply()"
    val impl1 = q"(x: Any) => $rm.reflect($instance).reflectMethod($methodSymbol).apply(x)"
    val api = q"_root_.scala.reflect.api"

    def typeToTypeTag(T: Tree, tpe: Tree): Tree =
      q"""
        $ru.TypeTag[$T]($rm, new $api.TypeCreator {
          override def apply[U <: $api.Universe with _root_.scala.Singleton](m: $api.Mirror[U]) =
            $tpe.asInstanceOf[U#Type]
        })
      """

    val impl2 = TermName(c.freshName("impl2"))
    val x = TypeName(c.freshName("x"))
    val y = TypeName(c.freshName("y"))
    q"""
      $impl1 match {
        case $impl2: _root_.scala.Function1[$x, $y] => spark.udf.register("foo", $impl2)(
          ${typeToTypeTag(tq"$y", returnType)},
          ${typeToTypeTag(tq"$x", argType)}
        )
      }
    """
  }
}
// in a different subproject

import org.apache.spark.sql.SparkSession
import org.reflections.Reflections
import scala.jdk.CollectionConverters._

trait Base
class A extends Base /*with Serializable*/ {
  def impl(x: Int): String = "ab" + x.toString
}

object App {
  val spark: SparkSession = SparkSession.builder()
    .master("local")
    .appName("Spark app")
    .getOrCreate()

  val reflections = new Reflections()
  val classes: List[Class[_ <: Base]] = reflections.getSubTypesOf(classOf[Base]).asScala.toList

  classes.foreach(clazz =>
    Macros.registerMethod(clazz)
  )

  def main(args: Array[String]): Unit = {
    spark.sql("""SELECT foo(10)""").show()
  }
}

https://gist.github.com/DmytroMitin/c926158a9ff94a6539097c603bbedf6a

Upvotes: 1
