Yuehan Lyu

Reputation: 1094

Pass column and a Map to a Scala UDF

I come from PySpark. I know how to do this in PySpark, but I haven't managed to do the same thing in Scala.

Here is a dataframe,

val df = Seq(
  ("u1", Array[Int](2,3,4)),
  ("u2", Array[Int](7,8,9))
).toDF("id", "mylist")


//    +---+---------+
//    | id|   mylist|
//    +---+---------+
//    | u1|[2, 3, 4]|
//    | u2|[7, 8, 9]|
//    +---+---------+

and here is a Map object,

val myMap = (1 to 4).toList.map(x => (x,0)).toMap

//myMap: scala.collection.immutable.Map[Int,Int] = Map(1 -> 0, 2 -> 0, 3 -> 0, 4 -> 0)

so this map has keys from 1 to 4.

For each row of df, I want to check whether any element of "mylist" appears in myMap as a key. If myMap contains such an element, return it (any one of them if several match); otherwise return -1.

So the result should look like

    +---+---------+-----+
    | id|   mylist|label|
    +---+---------+-----+
    | u1|[2, 3, 4]|    2|
    | u2|[7, 8, 9]|   -1|
    +---+---------+-----+

I have tried the following approaches:

  1. The function below works for an array object, but does not work for a column:
def list2label(ls: Array[Int], m: Map[Int, Int]): Int = {
  var flag = 0
  for (element <- ls) {
    if (m.contains(element)) flag = element
  }
  flag
}

val testls = Array[Int](2,3,4)
list2label(testls, myMap)

//testls: Array[Int] = Array(2, 3, 4)
//res33: Int = 4
  2. Trying to use a UDF, but I got an error:
def list2label_udf(m: Map[Int, Int]) = udf( (ls: Array[Int]) =>(

                    var flag = 0
                    for (element <- ls) {
                        if (m.contains(element)) flag = element
                    }
                    flag
    )
)

//<console>:3: error: illegal start of simple expression
//                    var flag = 0
//                    ^

I think my UDF is in the wrong format (maybe the multi-statement body needs braces instead of parentheses?).

  3. In PySpark I can do this as I wish:
%pyspark

from pyspark.sql.functions import col, udf

myDict = {1: 0, 2: 0, 3: 0, 4: 0}

def list2label(ls, myDict):
    for i in ls:
        if i in myDict:
            return i
    return -1

def list2label_UDF(myDict):
    return udf(lambda c: list2label(c, myDict))

df = df.withColumn("label", list2label_UDF(myDict)(col("mylist")))

Any help would be appreciated!

Upvotes: 2

Views: 970

Answers (1)

Anand Sai

Reputation: 1586

The solution is shown below:

scala> df.show
+---+---------+
| id|   mylist|
+---+---------+
| u1|[2, 3, 4]|
| u2|[7, 8, 9]|
+---+---------+


scala> def customUdf(m: Map[Int,Int]) = udf((s: Seq[Int]) => {
          val intersection = s.toList.intersect(m.keys.toList)
          if(intersection.isEmpty) -1 else intersection(0)})

customUdf: (m: Map[Int,Int])org.apache.spark.sql.expressions.UserDefinedFunction

scala> df.select($"id", $"myList", customUdf(myMap)($"myList").as("new_col")).show
+---+---------+-------+
| id|   myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]|      2|
| u2|[7, 8, 9]|     -1|
+---+---------+-------+

Another approach could be to pass the list of the map's keys instead of the map itself, as you are only checking against the keys. The solution for this is shown below:

scala> def customUdf1(m: List[Int]) = udf((s: Seq[Int]) => {
          val intersection = s.toList.intersect(m)
          if(intersection.isEmpty) -1 else intersection(0)})

customUdf1: (m: List[Int])org.apache.spark.sql.expressions.UserDefinedFunction

scala> df.select($"id",$"myList", customUdf1(myMap.keys.toList)($"myList").as("new_col")).show
+---+---------+-------+
| id|   myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]|      2|
| u2|[7, 8, 9]|     -1|
+---+---------+-------+
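
As a side note, a similar result can be had without a UDF at all by using Spark's built-in array functions. Below is a minimal sketch, assuming Spark 2.4+ (array_intersect and element_at were introduced in 2.4): it builds a literal array column from the map's keys, intersects it with mylist, and takes the first match, with coalesce supplying the -1 fallback since element_at returns null on an empty array.

import org.apache.spark.sql.functions.{array_intersect, coalesce, element_at, lit, typedLit}

// Literal array column holding the map's keys.
val keysCol = typedLit(myMap.keys.toSeq)

// Intersect each row's mylist with the keys; element_at(..., 1) picks the
// first common element (null if there is none), and coalesce maps null to -1.
df.withColumn("label",
  coalesce(element_at(array_intersect($"mylist", keysCol), 1), lit(-1))
).show

Staying UDF-free also lets Catalyst optimize the whole expression instead of treating it as a black box.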

Let me know if it helps!!

Upvotes: 2
