stack0114106

Reputation: 8711

spark - stack multiple when conditions from an Array of column expressions

I have the below spark dataframe:

val df = Seq(("US",10),("IND",20),("NZ",30),("CAN",40)).toDF("a","b")
df.show(false)
+---+---+
|a  |b  |
+---+---+
|US |10 |
|IND|20 |
|NZ |30 |
|CAN|40 |
+---+---+

and I'm applying the when() condition as follows (us_list, i_list and n_list are defined further down):

df.withColumn("x",
  when(col("a").isin(us_list: _*), "u")
    .when(col("a").isin(i_list: _*), "i")
    .when(col("a").isin(n_list: _*), "n")
    .otherwise("-"))
  .show(false)

+---+---+---+
|a  |b  |x  |
+---+---+---+
|US |10 |u  |
|IND|20 |i  |
|NZ |30 |n  |
|CAN|40 |-  |
+---+---+---+

Now, to minimize the code, I'm trying the following:

val us_list = Array("U","US")
val i_list = Array("I","IND")
val n_list = Array("N","NZ")
val ar1 = Array((us_list,"u"),(i_list,"i"),(n_list,"n"))

val ap = ar1.map( x => when(col("a").isInCollection(x._1),x._2) )

This results in

ap: Array[org.apache.spark.sql.Column] = Array(CASE WHEN (a IN (U, US)) THEN u END, CASE WHEN (a IN (I, IND)) THEN i END, CASE WHEN (a IN (N, NZ)) THEN n END)

but when I try

val ap = ar1.map( x => when(col("a").isInCollection(x._1),x._2) ).reduce( (x,y) => x.y )

I get a compile error, since x.y is not a valid way to combine two Columns. How can I fix this?

Upvotes: 1

Views: 865

Answers (2)

mck

Reputation: 42392

There is usually no need to combine when statements using reduce/fold; coalesce is enough, because a when without an otherwise evaluates to null when its condition is false, and coalesce returns the first non-null value. It also saves you from specifying otherwise: just append one more column, the default value, to the list of arguments to coalesce.

val ar1 = Array((us_list,"u"),(i_list,"i"),(n_list,"n"))
val ap = ar1.map( x => when(col("a").isInCollection(x._1),x._2) )
val combined = coalesce(ap :+ lit("-"): _*)

df.withColumn("x", combined).show
+---+---+---+
|  a|  b|  x|
+---+---+---+
| US| 10|  u|
|IND| 20|  i|
| NZ| 30|  n|
|CAN| 40|  -|
+---+---+---+
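To see why this works: a when(...) with no otherwise yields null for rows where its condition is false, so coalesce simply picks the first branch that fired and falls through to the literal "-" otherwise. A minimal sketch of that null behavior (assuming the df from the question and a spark-shell session with the usual functions imported):

```scala
import org.apache.spark.sql.functions.{col, when}

// Without .otherwise, non-matching rows produce null ...
df.select(col("a"), when(col("a").isInCollection(Seq("U", "US")), "u").alias("w")).show
// ... and null branches are exactly what coalesce skips over.
```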

Upvotes: 1

blackbishop

Reputation: 32680

You can use foldLeft on the ar1 array:

val x = ar1.foldLeft(lit("-")) { case (acc, (list, value)) =>
  when(col("a").isin(list: _*), value).otherwise(acc)
}

// x: org.apache.spark.sql.Column = CASE WHEN (a IN (N, NZ)) THEN n ELSE CASE WHEN (a IN (I, IND)) THEN i ELSE CASE WHEN (a IN (U, US)) THEN u ELSE - END END END
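One thing to note: foldLeft nests the branches in reverse order, so the last pair in ar1 becomes the outermost WHEN (visible in the CASE expression above). That is harmless here because the membership lists are disjoint. Applying the column is then a sketch along these lines (assuming df and ar1 from the question):

```scala
// Apply the folded column; for disjoint lists this matches the earlier output.
// If the lists could overlap, later entries in ar1 would win, since they
// end up outermost and are therefore checked first.
df.withColumn("x", x).show(false)
```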

Upvotes: 1
