Reputation: 1094
I come from Pyspark. I know how to do this in Pyspark but haven't managed to do the same thing in Scala.
Here is a dataframe,
val df = Seq(
("u1", Array[Int](2,3,4)),
("u2", Array[Int](7,8,9))
).toDF("id", "mylist")
// +---+---------+
// | id| mylist|
// +---+---------+
// | u1|[2, 3, 4]|
// | u2|[7, 8, 9]|
// +---+---------+
and here is a Map object,
val myMap = (1 to 4).toList.map(x => (x,0)).toMap
//myMap: scala.collection.immutable.Map[Int,Int] = Map(1 -> 0, 2 -> 0, 3 -> 0, 4 -> 0)
so this map has keys from 1 to 4.
For each row of df, I want to check whether any element of "mylist" is a key of myMap. If so, return that element (any one of them, if several are keys); otherwise return -1.
So the result should look like:
+---+---------+-------+
| id| mylist| label|
+---+---------+-------+
| u1|[2, 3, 4]| 2 |
| u2|[7, 8, 9]| -1 |
+---+---------+-------+
I have tried the following approaches:
def list2label(ls: Array[Int], m: Map[Int, Int]): Int = {
  var flag = 0
  for (element <- ls) {
    if (m.contains(element)) flag = element
  }
  flag
}
val testls = Array[Int](2,3,4)
list2label(testls, myMap)
//testls: Array[Int] = Array(2, 3, 4)
//res33: Int = 4
def list2label_udf(m: Map[Int, Int]) = udf( (ls: Array[Int]) => (
  var flag = 0
  for (element <- ls) {
    if (m.contains(element)) flag = element
  }
  flag
  )
)
//<console>:3: error: illegal start of simple expression
// var flag = 0
// ^
I think my udf is in the wrong format. For reference, here is the Pyspark version that works for me:
%pyspark
myDict = {1:0, 2:0, 3:0, 4:0}

def list2label(ls, myDict):
    for i in ls:
        if i in myDict:
            return i
    return 0

def list2label_UDF(myDict):
    return udf(lambda c: list2label(c, myDict))

df = df.withColumn("label", list2label_UDF(myDict)(col("mylist")))
Any help would be appreciated!
Upvotes: 2
Views: 970
Reputation: 1586
The solution is shown below:
scala> df.show
+---+---------+
| id| mylist|
+---+---------+
| u1|[2, 3, 4]|
| u2|[7, 8, 9]|
+---+---------+
scala> def customUdf(m: Map[Int,Int]) = udf((s: Seq[Int]) => {
         val intersection = s.toList.intersect(m.keys.toList)
         if (intersection.isEmpty) -1 else intersection(0)
       })
customUdf: (m: Map[Int,Int])org.apache.spark.sql.expressions.UserDefinedFunction
scala> df.select($"id", $"myList", customUdf(myMap)($"myList").as("new_col")).show
+---+---------+-------+
| id| myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]| 2|
| u2|[7, 8, 9]| -1|
+---+---------+-------+
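As a side note on why the original udf attempt did not compile: a multi-statement anonymous function body in Scala must be wrapped in braces rather than parentheses, and Spark hands an array column to a Scala udf as a Seq[Int] rather than an Array[Int]. A minimal sketch of the question's own loop with just those two changes (assuming the same myMap as above) would be:

import org.apache.spark.sql.functions.udf

// Sketch only: the question's loop, with a braced body and a Seq[Int] parameter.
// Uses -1 as the default to match the desired output (the original used 0).
// Note: this keeps the last matching element; the question accepts any match.
def list2label_udf(m: Map[Int, Int]) = udf((ls: Seq[Int]) => {
  var flag = -1
  for (element <- ls) {
    if (m.contains(element)) flag = element
  }
  flag
})

// Usage, e.g.: df.withColumn("label", list2label_udf(myMap)($"mylist"))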
Another approach could be to pass the list of the map's keys instead of the map itself, since you are only checking against the keys. The solution for this is shown below:
scala> def customUdf1(m: List[Int]) = udf((s: Seq[Int]) => {
         val intersection = s.toList.intersect(m)
         if (intersection.isEmpty) -1 else intersection(0)
       })
customUdf1: (m: List[Int])org.apache.spark.sql.expressions.UserDefinedFunction
scala> df.select($"id",$"myList", customUdf1(myMap.keys.toList)($"myList").as("new_col")).show
+---+---------+-------+
| id| myList|new_col|
+---+---------+-------+
| u1|[2, 3, 4]| 2|
| u2|[7, 8, 9]| -1|
+---+---------+-------+
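A small variation on the same idea, if you would rather not build the whole intersection: use find on the incoming sequence and fall back to -1 when nothing matches. This is only a sketch, assuming the same df and myMap as above:

import org.apache.spark.sql.functions.udf

// Sketch: first element of the array that is a key of the map, else -1.
def customUdf2(m: Map[Int, Int]) = udf((s: Seq[Int]) => s.find(m.contains).getOrElse(-1))

df.select($"id", $"mylist", customUdf2(myMap)($"mylist").as("new_col")).show

This stops at the first matching element instead of intersecting the full lists, so for u1 it is guaranteed to return 2.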
Let me know if it helps!!
Upvotes: 2