Shad Amez

Reputation: 428

How to create UDF in Spark to support custom predicate

I have a dataframe with a column of list type that needs to be matched in a cross-join: if any element of one list exists in the other list, the two records should be considered a match.

Example.

import org.apache.spark.sql.functions.udf


val df = sc.parallelize(Seq(("one", List(1,34,3)), ("one", List(1,2,3)), ("two", List(1))))
          .toDF("word", "count")

val lsEqual = (xs : (List[Int],List[Int])) => xs._1.find(xs._2.contains(_)).nonEmpty
val equalList = udf(lsEqual)

But the join below gives me the following error:

val out =  df.joinWith(df,equalList(df("count"),df("count")),"cross")
java.lang.ClassCastException: $anonfun$1 cannot be cast to scala.Function2
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(ScalaUDF.scala:97)
at org.apache.spark.sql.expressions.UserDefinedFunction.apply(UserDefinedFunction.scala:56)
... 50 elided

Is there any other way to create custom predicates?

Upvotes: 0

Views: 645

Answers (1)

Ramesh Maharjan

Reputation: 41957

Your lsEqual function definition is wrong in two ways. First, List, Seq, and Array columns are passed to a UDF as WrappedArray in Spark DataFrames. Second, you are calling the UDF with two columns, so the function must take two separate parameters rather than a single tuple; that is why Spark cannot cast it to scala.Function2.

The correct way should be

val lsEqual = (xs1 : scala.collection.mutable.WrappedArray[Int], xs2 : scala.collection.mutable.WrappedArray[Int]) => xs1.find(xs2.contains(_)).nonEmpty

which should remove the error you are facing.
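For completeness, here is a minimal sketch (reusing the df and column names from the question) of how the corrected function plugs back into the cross join once wrapped in udf:

import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray

// List columns reach the UDF as WrappedArray, and the function must take
// two separate parameters so Spark can wrap it as a two-argument UDF.
val lsEqual = (xs1: WrappedArray[Int], xs2: WrappedArray[Int]) =>
  xs1.find(xs2.contains(_)).nonEmpty

val equalList = udf(lsEqual)

// Cross join that keeps only pairs whose lists share at least one element.
val out = df.joinWith(df, equalList(df("count"), df("count")), "cross")

If Spark complains about ambiguous column references in the self-join, aliasing the two sides (for example df.as("a") and df.as("b")) and referring to the aliased columns usually disambiguates them.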

Upvotes: 1
