czr_RR

Reputation: 591

Conditional Spark map() function based on input columns

What I'm trying to achieve here is to pass to the Spark SQL map function a set of conditionally generated columns, depending on whether they contain null, 0, or any other value I may want to exclude.

Take for example this initial DF.

// java.lang.Integer so the third field can hold null
val initialDF = Seq[(String, String, Integer)](
  ("a", "b", 1),
  ("a", "b", null),
  ("a", null, 0)
).toDF("field1", "field2", "field3")

From that initial DataFrame I want to generate yet another column which will be a map, like this.

initialDF.withColumn("thisMap", MY_FUNCTION)

My current approach is basically to take a Seq[String] in a method and flatMap it into the key-value pairs that the Spark SQL map function receives, like this.

def toMap(columns: String*): Column = {
  map(
    columns.flatMap(column => List(lit(column), col(column))): _*
  )
}
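
A call to it looks roughly like this (just a sketch):

// Every column always ends up in the resulting map, even when its value is null or 0
val withMap = initialDF.withColumn("thisMap", toMap("field1", "field2", "field3"))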

But then the filtering has to happen on the Scala side and becomes quite a mess, something along the lines of the sketch below.
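
(The stripNullOrZero UDF here is only a hypothetical illustration of that Scala-side filtering, not what I actually want to end up with.)

import org.apache.spark.sql.functions.udf

// Build the full map first, then strip the null / "0" entries per row on the Scala side
val stripNullOrZero = udf { m: Map[String, String] =>
  m.filter { case (_, v) => v != null && v != "0" }
}

initialDF.withColumn("thisMap", stripNullOrZero(toMap("field1", "field2", "field3")))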

What I would like to obtain after the processing, for each of those rows, is the following DataFrame.

val initialDF = Seq(
  ("a", "b", 1, Map("field1" -> "a", "field2" -> "b", "field3" -> 1)),
  ("a", "b", null, Map("field1" -> "a", "field2" -> "b")),
  ("a", null, 0, Map("field1" -> "a"))
)
  .toDF("field1", "field2", "field3", "thisMap")

I was wondering if this can be achieved using the Column API, which is much more intuitive, with .isNull or .equalTo?

Upvotes: 1

Views: 2810

Answers (2)

Fqp

Reputation: 143

Here's a small improvement on Lamanus' answer above which only loops over df.columns once:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

case class Record(field1: String, field2: String, field3: java.lang.Integer)

val df = Seq(
  Record("a", "b", 1),
  Record("a", "b", null),
  Record("a", null, 0)
).toDS

df.show

// +------+------+------+
// |field1|field2|field3|
// +------+------+------+
// |     a|     b|     1|
// |     a|     b|  null|
// |     a|  null|     0|
// +------+------+------+

// One single-entry map per column (or an empty map when the value is null or 0), merged with map_concat
df.withColumn("thisMap", map_concat(
  df.columns.map { colName =>
    when(col(colName).isNull or col(colName) === 0, map())
      .otherwise(map(lit(colName), col(colName)))
  }: _*
)).show(false)

// +------+------+------+---------------------------------------+
// |field1|field2|field3|thisMap                                |
// +------+------+------+---------------------------------------+
// |a     |b     |1     |[field1 -> a, field2 -> b, field3 -> 1]|
// |a     |b     |null  |[field1 -> a, field2 -> b]             |
// |a     |null  |0     |[field1 -> a]                          |
// +------+------+------+---------------------------------------+

Upvotes: 2

Lamanus

Reputation: 13581

UPDATE

I found a way to achieve the expected result but it is a bit dirty.

// One single-entry map per column; keep it only when its value is neither null nor 0, then merge them
val df2 = df.columns.foldLeft(df) { (acc, n) => acc.withColumn(n + "_map", map(lit(n), col(n))) }
val col_cond = df.columns.map { n =>
  when(not(col(n + "_map").getItem(n).isNull || col(n + "_map").getItem(n) === lit("0")), col(n + "_map"))
    .otherwise(map())
}
df2.withColumn("map", map_concat(col_cond: _*)).show(false)
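
To be left with only the original columns plus the final map, the intermediate *_map helper columns can be dropped afterwards (a small addition to the snippet above):

// Drop the *_map helper columns once the merged map has been built
df2.withColumn("map", map_concat(col_cond: _*))
  .drop(df.columns.map(_ + "_map"): _*)
  .show(false)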

ORIGINAL

Here is my attempt with the map_from_arrays function, which is available in Spark 2.4+.

df.withColumn("array", array(df.columns.map(col): _*))
  .withColumn("map", map_from_arrays(lit(df.columns), $"array")).show(false)

Then, the result is:

+------+------+------+---------+---------------------------------------+
|field1|field2|field3|array    |map                                    |
+------+------+------+---------+---------------------------------------+
|a     |b     |1     |[a, b, 1]|[field1 -> a, field2 -> b, field3 -> 1]|
|a     |b     |null  |[a, b,]  |[field1 -> a, field2 -> b, field3 ->]  |
|a     |null  |0     |[a,, 0]  |[field1 -> a, field2 ->, field3 -> 0]  |
+------+------+------+---------+---------------------------------------+

Upvotes: 2
