Manas Mukherjee

Reputation: 5340

How to add new columns and the corresponding row-specific values to a Spark DataFrame?

I'm new to the Scala/Spark world.

I have a Spark Dataset (of a case class) called person.

scala> val person_with_contact = person.map(r => (
     | r.id,
     | r.name,
     | r.age
     | )).toDF()

Now, I want to add a list of address attributes (like apt_no, street, city, zip) to each record of this dataset. To get the address attributes, I have a function that takes a person's id as input and returns a map containing all the address attributes and their corresponding values.
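
For concreteness, that lookup function has roughly this shape (the real body is omitted; ??? is a placeholder, and the Int id type is an assumption):

def getAddress(id: Int): Map[String, String] = {
  // e.g. Map("apt_no" -> "12", "street" -> "Main Street", "city" -> "NY", "zip" -> "10001")
  ???
}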

I tried the following, along with a few other approaches suggested on Stack Overflow, but I couldn't get it to work. (Ref: static column example - Spark, add new Column with the same value in Scala)

scala> val person_with_contact = person.map(r => (
    | r.id,
    | r.name,
    | r.age,
    | getAddress(r.id) 
    | )).toDF()

The final dataframe should have the following columns.

id, name, age, apt_no, street, city, zip

Upvotes: 0

Views: 1277

Answers (2)

Hristo Iliev

Reputation: 74355

Given that you already have a function that returns the address as a map, you can create a UDF that converts that map to a struct and then select all map fields:

import org.apache.spark.sql.functions._

// For demo only
def getAddress(id: Int): Option[Map[String, String]] = {
  id match {
    case 1 => Some(Map("apt_no" -> "12", "street" -> "Main Street", "city" -> "NY", "zip" -> "1234"))
    case 2 => Some(Map("apt_no" -> "1", "street" -> "Back Street", "city" -> "Gotham", "zip" -> "G123"))
    case _ => None
  }
}

case class Address(apt_no: String, street: String, city: String, zip: String)

// Wraps getAddress in a UDF; a None result becomes a null struct column
def getAddressUdf = udf((id: Int) => {
  getAddress(id) map (m =>
    Address(m("apt_no"), m("street"), m("city"), m("zip"))
  )
})

udf() transforms functions that return case class instances into UDFs that return struct columns with the corresponding schema. Option[_] return types are automatically resolved to nullable datatypes. The fields of the struct column can then be expanded into multiple columns by selecting them with $"struct_col_name.*":

scala> case class Person(id: Int, name: String, age: Int)
defined class Person

scala> val df = Seq(Person(1, "John", 32), Person(2, "Cloe", 27), Person(3, "Pete", 55)).toDS()
df: org.apache.spark.sql.Dataset[Person] = [id: int, name: string ... 1 more field]

scala> df.show()
+---+----+---+
| id|name|age|
+---+----+---+
|  1|John| 32|
|  2|Cloe| 27|
|  3|Pete| 55|
+---+----+---+

scala> df
     | .withColumn("addr", getAddressUdf($"id"))
     | .select($"id", $"name", $"age", $"addr.*")
     | .show()
+---+----+---+------+------------+------+-----+
| id|name|age|apt_no|      street|  city|  zip|
+---+----+---+------+------------+------+-----+
|  1|John| 32|    12| Main Street|    NY| 1234|
|  2|Cloe| 27|     1| Back Street|Gotham| G123|
|  3|Pete| 55|  null|        null|  null| null|
+---+----+---+------+------------+------+-----+
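
If you'd rather skip defining the Address case class, a possible alternative (a sketch reusing the demo getAddress above) is to return the map itself from the UDF and extract the keys with getItem; a missing key or a null map simply yields null:

val getAddressMapUdf = udf((id: Int) => getAddress(id))

df.withColumn("addr", getAddressMapUdf($"id"))
  .select($"id", $"name", $"age",
    $"addr".getItem("apt_no").as("apt_no"),
    $"addr".getItem("street").as("street"),
    $"addr".getItem("city").as("city"),
    $"addr".getItem("zip").as("zip"))
  .show()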

Upvotes: 0

wenjiangFu

Reputation: 11

Use a UDF:

package yourpackage

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._


case class Person(id: Int, name: String, age: Int)

object MainDemo {

  def getAddress(id: Int): String = {
    //do your things
    "address id:" + id
  }

  def getCity(id: String): String = {
    //do your things
    "your city :" + id
  }

  def getZip(id: String): String = {
    //do your things
    "your zip :" + id
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName(this.getClass.getSimpleName).master("local[3]").getOrCreate()
    val person = Seq(Person(1, "name_m", 21), Person(2, "name_w", 40))
    import spark.implicits._
    val person_with_contact = person.map(r => (r.id, r.name, r.age, getAddress(r.id))).toDF("id", "name", "age", "street")
    person_with_contact.printSchema()
    //root
    // |-- id: integer (nullable = false)
    // |-- name: string (nullable = true)
    // |-- age: integer (nullable = false)
    // |-- street: string (nullable = true)
    val result = person_with_contact.select(
      col("id"),
      col("name"),
      col("age"),
      col("street"),
      udf { id: String =>
        getCity(id)
      }.apply(col("id")).as("city"),
      udf { id: String =>
        getZip(id)
      }.apply(col("id")).as("zip")
    )
    result.printSchema()
    //root
    // |-- id: integer (nullable = false)
    // |-- name: string (nullable = true)
    // |-- age: integer (nullable = false)
    // |-- street: string (nullable = true)
    // |-- city: string (nullable = true)
    // |-- zip: string (nullable = true)
    result.show()
    //+---+------+---+------------+------------+-----------+
    //| id|  name|age|      street|        city|        zip|
    //+---+------+---+------------+------------+-----------+
    //|  1|name_m| 21|address id:1|your city :1|your zip :1|
    //|  2|name_w| 40|address id:2|your city :2|your zip :2|
    //+---+------+---+------------+------------+-----------+
  }
}
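
The same two columns could also be appended with withColumn instead of one big select; a minimal sketch, assuming the same getCity and getZip helpers:

val cityUdf = udf { id: String => getCity(id) }
val zipUdf = udf { id: String => getZip(id) }

val result2 = person_with_contact
  .withColumn("city", cityUdf(col("id")))
  .withColumn("zip", zipUdf(col("id")))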

Upvotes: 1
