Reputation: 428
I have a dataframe with a column of list type that needs to be matched in a cross join: if any element of one list exists in the other list, the two records should be considered a match.
Example.
import org.apache.spark.sql.functions.udf
val df = sc.parallelize(Seq(("one", List(1,34,3)), ("one", List(1,2,3)), ("two", List(1))))
.toDF("word", "count")
val lsEqual = (xs : (List[Int],List[Int])) => xs._1.find(xs._2.contains(_)).nonEmpty
val equalList = udf(lsEqual)
val out = df.joinWith(df, equalList(df("count"), df("count")), "cross")
But this gives me the following error:
java.lang.ClassCastException: $anonfun$1 cannot be cast to scala.Function2
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(ScalaUDF.scala:97)
at org.apache.spark.sql.expressions.UserDefinedFunction.apply(UserDefinedFunction.scala:56)
... 50 elided
Is there any other way to create custom predicates?
Upvotes: 0
Views: 645
Reputation: 41957
Your lsEqual function definition seems to be wrong. List, Seq, and Array columns are treated as WrappedArray in Spark DataFrames. Also, you are passing two columns to lsEqual, so the function should take two separate parameters rather than a single tuple.
The correct way should be
val lsEqual = (xs1 : scala.collection.mutable.WrappedArray[Int], xs2 : scala.collection.mutable.WrappedArray[Int]) => xs1.find(xs2.contains(_)).nonEmpty
which should definitely remove the error you are facing
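For completeness, here is a minimal sketch of the fixed UDF applied end to end, assuming a spark-shell style session (as in the question) where spark.implicits._ is already in scope. It uses an explicit crossJoin plus a filter, with the right-hand columns renamed, instead of joinWith; that restructuring is my own choice to keep the self-join column references unambiguous, not something the fix itself requires.
import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray

val df = sc.parallelize(Seq(("one", List(1, 34, 3)), ("one", List(1, 2, 3)), ("two", List(1))))
  .toDF("word", "count")

// Fixed UDF: two WrappedArray parameters instead of a single tuple
val lsEqual = (xs1: WrappedArray[Int], xs2: WrappedArray[Int]) => xs1.find(xs2.contains(_)).nonEmpty
val equalList = udf(lsEqual)

// Rename the right-hand copy so the columns stay distinct, then cross join
// and keep only the rows whose lists share at least one element
val other = df.toDF("word2", "count2")
val out = df.crossJoin(other).filter(equalList(df("count"), other("count2")))
out.show(false)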
Upvotes: 1