Reputation: 18043
Given the following:
val pivotDF = df.groupBy("Product").pivot("Country").sum("Amount")
pivotDF.show()
I cannot recall seeing any way to sort the pivoted columns, and I cannot find it documented. What ordering is assumed? Always ascending, or is it non-deterministic?
Tips welcome.
Upvotes: 1
Views: 1023
Reputation: 8711
Using spark-sql:
Extract the distinct pivot values into a collection, sort them ascending or descending as you wish, and then pass them to the PIVOT clause.
scala> val df = Seq(("Foo", "UK", 1), ("Bar", "UK", 5), ("Foo", "FR", 3), ("Bar", "FR", 4))
df: Seq[(String, String, Int)] = List((Foo,UK,1), (Bar,UK,5), (Foo,FR,3), (Bar,FR,4))
scala> val df2 = df.toDF("Product", "Country", "Amount")
df2: org.apache.spark.sql.DataFrame = [Product: string, Country: string ... 1 more field]
scala> df2.createOrReplaceTempView("vw1")
scala> df2.select("Country").distinct.orderBy("Country").as[String].collect.toList
res20: List[String] = List(FR, UK)
scala> val pivot_headers = res20.map( d => "'" + d + "'").mkString(",")
pivot_headers: String = 'FR','UK' // ascending order
scala> val pivot_headers2 = res20.reverse.map( d => "'" + d + "'").mkString(",")
pivot_headers2: String = 'UK','FR' // if you want it reversed
scala> spark.sql(s" select * from vw1 pivot(first(amount) for country in (${pivot_headers}) )").show(false)
+-------+---+---+
|Product|FR |UK |
+-------+---+---+
|Bar |4 |5 |
|Foo |3 |1 |
+-------+---+---+
scala> spark.sql(s" select * from vw1 pivot(first(amount) for country in (${pivot_headers2}) )").show(false)
+-------+---+---+
|Product|UK |FR |
+-------+---+---+
|Bar |5 |4 |
|Foo |1 |3 |
+-------+---+---+
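The same thing can be done without building a SQL string, via the DataFrame API's pivot(column, values) overload. A minimal sketch, reusing res20 from the transcript above (first comes from org.apache.spark.sql.functions):
scala> import org.apache.spark.sql.functions.first

scala> df2.groupBy("Product").pivot("Country", res20.reverse).agg(first("Amount")).show(false)
This yields the same UK-before-FR column layout as pivot_headers2 above.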
Upvotes: 1
Reputation: 4540
According to the Spark Scala API docs:
There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally.
Taking a look at how the latter one works:
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
// Get the distinct values of the column and sort them so its consistent
val values = df.select(pivotColumn)
.distinct()
.limit(maxValues + 1)
.sort(pivotColumn) // ensure that the output columns are in a consistent logical order
.collect()
.map(_.get(0))
.toSeq
and values is passed to the former version. So when using the version that auto-detects the values, the columns are always sorted using the natural ordering of the values. If you need a different ordering, it is easy enough to replicate the auto-detection mechanism and then call the version that takes explicit values:
val df = Seq(("Foo", "UK", 1), ("Bar", "UK", 1), ("Foo", "FR", 1), ("Bar", "FR", 1))
.toDF("Product", "Country", "Amount")
df.groupBy("Product")
.pivot("Country", Seq("UK", "FR")) // natural ordering would be "FR", "UK"
.sum("Amount")
.show()
Output:
+-------+---+---+
|Product| UK| FR|
+-------+---+---+
| Bar| 1| 1|
| Foo| 1| 1|
+-------+---+---+
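If you want a non-natural ordering computed at runtime, you can replicate the auto-detection snippet almost verbatim. A minimal sketch, assuming a spark-shell session (so spark and spark.implicits._ are in scope) and sorting descending instead:
import org.apache.spark.sql.functions.col

// spark.sql.pivotMaxValues is the public config behind dataFramePivotMaxValues
val maxValues = spark.conf.get("spark.sql.pivotMaxValues").toInt

val countriesDesc = df.select("Country")
  .distinct()
  .limit(maxValues + 1)
  .sort(col("Country").desc) // descending instead of the default ascending
  .as[String]
  .collect()
  .toSeq

df.groupBy("Product").pivot("Country", countriesDesc).sum("Amount").show()
With the sample data this produces the same UK, FR column order as the hard-coded Seq above.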
Upvotes: 2