Meyer Cohen

Reputation: 350

How to dynamically add columns to a DataFrame?

I am trying to dynamically add columns to a DataFrame from a Seq of String.

Here's an example: the source DataFrame looks like:

+---+----+----+----+----+
| id|   A|   B|   C|   D|
+---+----+----+----+----+
|  1|toto|tata|titi|    |
|  2| bla| blo|    |    |
|  3|   b|   c|   a|   d|
+---+----+----+----+----+

I also have a Seq of String which contains the names of the columns I want to add. Columns that already exist in the source DataFrame must be skipped, so some kind of set difference is needed, like below:

The Seq looks like :

val columns = Seq("A", "B", "F", "G", "H")

The expectation is:

+---+----+----+----+----+----+----+----+
| id|   A|   B|   C|   D|   F|   G|   H|
+---+----+----+----+----+----+----+----+
|  1|toto|tata|titi|    |null|null|null|
|  2| bla| blo|    |    |null|null|null|
|  3|   b|   c|   a|   d|null|null|null|
+---+----+----+----+----+----+----+----+

What I've done so far is something like this :

val difference = columns diff sourceDF.columns
val finalDF = difference.foldLeft(sourceDF)((df, field) =>
    if (!sourceDF.columns.contains(field)) df.withColumn(field, lit(null)) else df)
  .select(columns.head, columns.tail: _*)

But I can't figure out how to do this efficiently in Spark, in a simpler and more readable way ...
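Two details of the attempt can be checked in plain Scala, without Spark: `diff` already removes the names that exist in the source, so the `contains` test inside the fold never fires, and selecting only `columns` would drop `id`, `C` and `D`. A minimal sketch of that bookkeeping (the Seq values here stand in for `sourceDF.columns`):

```scala
// Plain-Scala check of the column bookkeeping (no Spark needed).
// `existing` stands in for sourceDF.columns; `wanted` is the Seq to add.
val existing = Seq("id", "A", "B", "C", "D")
val wanted   = Seq("A", "B", "F", "G", "H")

// diff already removes the names present in the source, so the
// `contains` check inside the fold is redundant.
val toAdd = wanted.diff(existing)

// The final select must keep the original columns too; selecting
// only `columns` would drop id, C and D.
val finalCols = existing ++ toAdd
```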

Thanks in advance

Upvotes: 2

Views: 2534

Answers (2)

abiratsis

Reputation: 7316

Here is another way, using Seq.diff, a single select and map to generate your final column list:

import org.apache.spark.sql.functions.{lit, col}


val newCols = Seq("A", "B", "F", "G", "H")

val updatedCols = newCols.diff(df.columns).map{ c => lit(null).as(c)}

val selectExpr = df.columns.map(col) ++ updatedCols

df.select(selectExpr:_*).show

// +---+----+----+----+----+----+----+----+
// | id|   A|   B|   C|   D|   F|   G|   H|
// +---+----+----+----+----+----+----+----+
// |  1|toto|tata|titi|null|null|null|null|
// |  2| bla| blo|null|null|null|null|null|
// |  3|   b|   c|   a|   d|null|null|null|
// +---+----+----+----+----+----+----+----+

First we find the diff between newCols and df.columns, which gives us F, G and H. Next we transform each element of that list into lit(null).as(c) via the map function. Finally, we concatenate the existing columns and the new list to produce selectExpr, which is used for the select.
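The same list can also be assembled as SQL fragment strings and fed to `df.selectExpr` instead of the typed select. A small sketch of the string assembly (the Seq values here stand in for `df.columns`, so this part runs without Spark):

```scala
// Existing column names pass through; each missing name becomes a
// "NULL AS name" fragment. With a DataFrame in scope,
// df.selectExpr(exprs: _*) would give the same result as the typed select.
val dfColumns = Seq("id", "A", "B", "C", "D")   // stands in for df.columns
val newCols   = Seq("A", "B", "F", "G", "H")
val exprs     = dfColumns ++ newCols.diff(dfColumns).map(c => s"NULL AS `$c`")
```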

Upvotes: 3

Nikhil Suthar

Reputation: 2431

Below is an optimised way to apply your logic.

scala> df.show
+---+----+----+----+----+
| id|   A|   B|   C|   D|
+---+----+----+----+----+
|  1|toto|tata|titi|null|
|  2| bla| blo|null|null|
|  3|   b|   c|   a|   d|
+---+----+----+----+----+

scala> val Columns  = Seq("A", "B", "F", "G", "H")

scala> val newCol =  Columns filterNot df.columns.toSeq.contains

scala> val df1 =  newCol.foldLeft(df)((df,name) => df.withColumn(name, lit(null)))
scala> df1.show()
+---+----+----+----+----+----+----+----+
| id|   A|   B|   C|   D|   F|   G|   H|
+---+----+----+----+----+----+----+----+
|  1|toto|tata|titi|null|null|null|null|
|  2| bla| blo|null|null|null|null|null|
|  3|   b|   c|   a|   d|null|null|null|
+---+----+----+----+----+----+----+----+

If you do not want to use foldLeft, you can use a runtime mirror (the Scala reflection toolbox) instead, which can be faster. Check the code below.

scala> import scala.reflect.runtime.universe.runtimeMirror
scala> import scala.tools.reflect.ToolBox
scala> import org.apache.spark.sql.DataFrame

scala> df.show
+---+----+----+----+----+
| id|   A|   B|   C|   D|
+---+----+----+----+----+
|  1|toto|tata|titi|null|
|  2| bla| blo|null|null|
|  3|   b|   c|   a|   d|
+---+----+----+----+----+


scala> def compile[A](code: String): DataFrame => A = {
     |     val tb = runtimeMirror(getClass.getClassLoader).mkToolBox()
     |     val tree = tb.parse(
     |       s"""
     |          |import org.elasticsearch.spark.sql._
     |          |import org.apache.spark.sql.DataFrame
     |          |def wrapper(context:DataFrame): Any = {
     |          |  $code
     |          |}
     |          |wrapper _
     |       """.stripMargin)
     | 
     |     val fun = tb.compile(tree)
     |     val wrapper = fun()
     |     wrapper.asInstanceOf[DataFrame => A]
     |   }


scala> def  AddColumns(df:DataFrame,withColumnsString:String):DataFrame = {
     |     val code =
     |       s"""
     |          |import org.apache.spark.sql.functions._
     |          |import org.elasticsearch.spark.sql._
     |          |import org.apache.spark.sql.DataFrame
     |          |var data = context.asInstanceOf[DataFrame]
     |          |data = data
     |       """ + withColumnsString +
     |         """
     |           |
     |           |data
     |         """.stripMargin
     | 
     |     val fun = compile[DataFrame](code) 
     |     val res = fun(df)
     |     res
     |   }


scala> val Columns = Seq("A", "B", "F", "G", "H")     
scala> val newCol =  Columns filterNot df.columns.toSeq.contains

scala> var cols = ""      
scala>  newCol.foreach{ name =>
     |  cols = ".withColumn(\""+ name + "\" , lit(null))" + cols
     | }

scala> val df1 = AddColumns(df,cols)
scala> df1.show
+---+----+----+----+----+----+----+----+
| id|   A|   B|   C|   D|   H|   G|   F|
+---+----+----+----+----+----+----+----+
|  1|toto|tata|titi|null|null|null|null|
|  2| bla| blo|null|null|null|null|null|
|  3|   b|   c|   a|   d|null|null|null|
+---+----+----+----+----+----+----+----+
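As an aside, the mutable `var cols` loop above can be written as a `foldLeft` over the new names; a small plain-Scala sketch building the same chain string (prepending on each step is what reverses the column order to H, G, F in the output above):

```scala
// Build the ".withColumn(...)" chain string functionally; each new
// name is prepended, so the last name processed (H) ends up first.
val newCol = Seq("F", "G", "H")
val cols = newCol.foldLeft("") { (chain, name) =>
  s""".withColumn("$name", lit(null))""" + chain
}
```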

Upvotes: 2
