Achilles

Reputation: 741

Pivoting DataFrame - Spark SQL

I have a DataFrame containing the following:

TradeId|Source
ABC|"USD,333.123,20170605|USD,-789.444,20170605|GBP,1234.567,20150602"

I want to pivot this data so that it turns into the following:

TradeId|CCY|PV
ABC|USD|333.123
ABC|USD|-789.444
ABC|GBP|1234.567

The number of CCY|PV|Date triplets in the "Source" column is not fixed. I could do this with an ArrayList, but that requires loading the data into the JVM and defeats the whole point of Spark.

Let's say my DataFrame is built as follows:

// Spark 1.x Java API: load the snapshot, register it as a temp table, select the two columns
DataFrame tradesSnap = this.loadTradesSnap(reportRequest);
String tempTable = getTempTableName();
tradesSnap.registerTempTable(tempTable);
tradesSnap = tradesSnap.sqlContext().sql("SELECT TradeId, Source FROM " + tempTable);

Upvotes: 0

Views: 676

Answers (2)

Ramesh Maharjan

Reputation: 41987

If you read the Databricks documentation on pivot, it says: "A pivot is an aggregation where one (or more in the general case) of the grouping columns has its distinct values transposed into individual columns." And that is not what you want here, I guess.
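
For contrast, here is what an actual pivot looks like in Spark (a minimal sketch; the df and its columns year, course and earnings are made up for illustration):

// A pivot transposes the distinct values of a grouping column into
// individual output columns: one column per distinct course, aggregated per year
df.groupBy("year").pivot("course").sum("earnings")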

I would suggest you use withColumn and the built-in functions to get the final output you want. You can do the following, assuming dataframe is what you have:

+-------+----------------------------------------------------------------+
|TradeId|Source                                                          |
+-------+----------------------------------------------------------------+
|ABC    |USD,333.123,20170605|USD,-789.444,20170605|GBP,1234.567,20150602|
+-------+----------------------------------------------------------------+

You can use explode, split and withColumn to get the desired output:

import org.apache.spark.sql.functions.{col, explode, split}

// Explode the pipe-separated triplets into one row each,
// then split every triplet into its three comma-separated fields
val explodedDF = dataframe.withColumn("Source", explode(split(col("Source"), "\\|")))
val finalDF = explodedDF
  .withColumn("CCY", split(col("Source"), ",")(0))
  .withColumn("PV", split(col("Source"), ",")(1))
  .withColumn("Date", split(col("Source"), ",")(2))
  .drop("Source")

finalDF.show(false)

The final output is

+-------+---+--------+--------+
|TradeId|CCY|PV      |Date    |
+-------+---+--------+--------+
|ABC    |USD|333.123 |20170605|
|ABC    |USD|-789.444|20170605|
|ABC    |GBP|1234.567|20150602|
+-------+---+--------+--------+
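
As a side note, you could also split each exploded triplet only once into a temporary array column and pick the fields out of it by index (a sketch under the same assumptions; adding .drop("Date") at the end would give exactly the TradeId|CCY|PV output from the question):

// Split once, then select the array elements by index
val finalDF2 = explodedDF
  .withColumn("parts", split(col("Source"), ","))
  .select(col("TradeId"),
          col("parts")(0).as("CCY"),
          col("parts")(1).as("PV"),
          col("parts")(2).as("Date"))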

I hope this solves your issue

Upvotes: 2

stefanobaghino

Reputation: 12814

Rather than pivoting, what you are trying to achieve looks more like a flatMap.

To put it simply, using flatMap on a Dataset applies to each row a function (map) that itself produces a sequence of rows; each resulting sequence is then concatenated into a single sequence (flat).
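
The same semantics can be seen on plain Scala collections (the data here is made up):

// map keeps one sequence per element; flatMap concatenates them
val rows = Seq("a,1|b,2", "c,3")
rows.map(_.split("\\|").toSeq)     // Seq(Seq("a,1", "b,2"), Seq("c,3"))
rows.flatMap(_.split("\\|").toSeq) // Seq("a,1", "b,2", "c,3")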

The following program shows the idea:

import org.apache.spark.sql.SparkSession

case class Input(TradeId: String, Source: String)

case class Output(TradeId: String, CCY: String, PV: String, Date: String)

object FlatMapExample {

  // This function produces one output row for each CCY,PV,Date triplet in the input line
  def splitSource(in: Input): Seq[Output] =
    in.Source.split("\\|", -1).map {
      source =>
        val Array(ccy, pv, date) = source.split(",", -1)
        Output(in.TradeId, ccy, pv, date)
    }

  def main(args: Array[String]): Unit = {

    // Initialization and loading
    val spark = SparkSession.builder().master("local").appName("pivoting-example").getOrCreate()
    import spark.implicits._
    val input = spark.read.options(Map("sep" -> "|", "header" -> "true")).csv(args(0)).as[Input]

    // For each line in the input, split the source and then
    // concatenate each "sub-sequence" into a single `Dataset`
    input.flatMap(splitSource).show
  }

}

Given your input, this would be the output:

+-------+---+--------+--------+
|TradeId|CCY|      PV|    Date|
+-------+---+--------+--------+
|    ABC|USD| 333.123|20170605|
|    ABC|USD|-789.444|20170605|
|    ABC|GBP|1234.567|20150602|
+-------+---+--------+--------+

You can now take the result and save it to a CSV, if you want.
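
For example (a minimal sketch; the output path is hypothetical):

input.flatMap(splitSource).write.option("header", "true").csv("/tmp/trades_flat")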

Upvotes: 2
