danD

Reputation: 716

Need to flatten a dataframe on the basis of one column in Scala

I have a dataframe with the below schema:

 root
 |-- name: string (nullable = true)
 |-- roll: string (nullable = true)
 |-- subjectID: string (nullable = true)

The values in the dataframe are as below

+-------------------+---------+--------------------+
|               name|     roll|           SubjectID|
+-------------------+---------+--------------------+
|                sam|ta1i3dfk4|            xy|av|mm|
|               royc|rfhqdbnb3|                   a|
|             alcaly|ta1i3dfk4|               xx|zz|
+-------------------+---------+--------------------+

I need to derive the dataframe by flattening SubjectID as below. Please note: SubjectID is a string as well.

+-------------------+---------+--------------------+
|               name|     roll|           SubjectID|
+-------------------+---------+--------------------+
|                sam|ta1i3dfk4|                  xy|
|                sam|ta1i3dfk4|                  av|
|                sam|ta1i3dfk4|                  mm|
|               royc|rfhqdbnb3|                   a|
|             alcaly|ta1i3dfk4|                  xx|
|             alcaly|ta1i3dfk4|                  zz|
+-------------------+---------+--------------------+

Any suggestions?

Upvotes: 0

Views: 92

Answers (2)

vkt

Reputation: 1459

You can use the `split` and `explode` functions to flatten the column. Example:

import org.apache.spark.sql.functions.{col, explode, split}
import spark.implicits._

val inputDF = Seq(
  ("sam", "ta1i3dfk4", "xy|av|mm"),
  ("royc", "rfhqdbnb3", "a"),
  ("alcaly", "rfhqdbnb3", "xx|zz")
).toDF("name", "roll", "subjectIDs")

// split `subjectIDs` on "|" and explode the resulting array into one row per ID
val resultDF = inputDF
  .withColumn("subjectIDs", split(col("subjectIDs"), "\\|"))
  .withColumn("subjectIDs", explode(col("subjectIDs")))

resultDF.show()

    +------+---------+----------+ 
    |  name|     roll|subjectIDs|
    +------+---------+----------+
    |   sam|ta1i3dfk4|        xy|
    |   sam|ta1i3dfk4|        av|
    |   sam|ta1i3dfk4|        mm|
    |  royc|rfhqdbnb3|         a|
    |alcaly|rfhqdbnb3|        xx|
    |alcaly|rfhqdbnb3|        zz|
    +------+---------+----------+

Upvotes: 2
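The same row-expansion can be sketched with plain Scala collections (sample data only, no Spark session), which is a quick way to sanity-check the split pattern before running it on a cluster:

```scala
// Plain-Scala sketch of the flattening: each (name, roll, subjectIDs) row is
// expanded into one row per subject ID, mirroring split + explode.
val rows = Seq(
  ("sam", "ta1i3dfk4", "xy|av|mm"),
  ("royc", "rfhqdbnb3", "a"),
  ("alcaly", "rfhqdbnb3", "xx|zz")
)

val flattened = rows.flatMap { case (name, roll, ids) =>
  ids.split("\\|").map(id => (name, roll, id))
}
// flattened has 6 rows: the "sam" row becomes (sam, xy), (sam, av), (sam, mm), etc.
```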

Mikhail Ionkin

Reputation: 617

You can use `flatMap` on a Dataset. Full executable code:

package main

import org.apache.spark.sql.{Dataset, SparkSession}

object Main extends App {
  case class Roll(name: Option[String], roll: Option[String], subjectID: Option[String])

  val mySpark = SparkSession
    .builder()
    .master("local[2]")
    .appName("Spark SQL basic example")
    .getOrCreate()
  import mySpark.implicits._

  val inputDF: Dataset[Roll] = Seq(
    ("sam", "ta1i3dfk4", "xy|av|mm"),
    ("royc", "rfhqdbnb3", "a"),
    ("alcaly", "rfhqdbnb3", "xx|zz")
  ).toDF("name", "roll", "subjectID").as[Roll]

  val out: Dataset[Roll] = inputDF.flatMap {
    case Roll(n, r, Some(ids)) if ids.nonEmpty =>
      ids.split("\\|").map(id => Roll(n, r, Some(id)))
    case x => Some(x)
  }
  out.show()
}

Note:

  1. you can use split('|') instead of split("\\|"): the Char overload splits on the literal character, while the String overload treats its argument as a regex, which is why the pipe must be escaped
  2. you can change the default case if the id must be non-empty
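Note 1 can be checked in plain Scala: the Char overload of `split` matches the literal character, while the String overload compiles its argument as a regex, so the pipe must be escaped there:

```scala
val ids = "xy|av|mm"

// Char overload: splits on the literal '|' character
val byChar = ids.split('|')

// String overload: the argument is a regex, so "|" must be escaped.
// An unescaped "|" is the alternation of two empty patterns and would
// instead split between every character.
val byRegex = ids.split("\\|")
```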

Upvotes: 1
