vinay patil

Reputation: 1

How to convert columns to rows in Spark Scala or Spark SQL?

I have data like this:

+------+------+------+----------+----------+----------+----------+----------+----------+
| Col1 | Col2 | Col3 | Col1_cnt | Col2_cnt | Col3_cnt | Col1_wts | Col2_wts | Col3_wts |
+------+------+------+----------+----------+----------+----------+----------+----------+
| AAA  | VVVV | SSSS |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
| BBB  | BBBB | TTTT |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
| CCC  | DDDD | YYYY |        3 |        4 |        5 |      0.5 |      0.4 |      0.6 |
+------+------+------+----------+----------+----------+----------+----------+----------+

I have tried the following, but it is not getting me anywhere:

val df = Seq(("G",Some(4),2,None),("H",None,4,Some(5))).toDF("A","X","Y", "Z")

I want the output in the form of the table below:

+-----------+---------+---------+
| Cols_name | Col_cnt | Col_wts |
+-----------+---------+---------+
| Col1      |       3 |     0.5 |
| Col2      |       4 |     0.4 |
| Col3      |       5 |     0.6 |
+-----------+---------+---------+

Upvotes: 0

Views: 3122

Answers (1)

Leo C

Reputation: 22439

Here's a general approach for transposing a DataFrame:

  1. For each of the pivot columns (say c1, c2, c3), combine the column name and associated value columns into a struct (e.g. struct(lit(c1), c1_cnt, c1_wts))
  2. Put all these struct-typed columns into an array, which is then exploded into rows of struct columns
  3. Group by the pivot column name to aggregate the associated struct elements

The following sample code has been generalized to handle an arbitrary list of columns to be transposed:

import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  ("AAA", "VVVV", "SSSS", 3, 4, 5, 0.5, 0.4, 0.6),
  ("BBB", "BBBB", "TTTT", 3, 4, 5, 0.5, 0.4, 0.6),
  ("CCC", "DDDD", "YYYY", 3, 4, 5, 0.5, 0.4, 0.6)
).toDF("c1", "c2", "c3", "c1_cnt", "c2_cnt", "c3_cnt", "c1_wts", "c2_wts", "c3_wts")

val pivotCols = Seq("c1", "c2", "c3")

val valueColSfx = Seq("_cnt", "_wts")

// For each pivot column, assemble a struct of the column name plus its "_cnt" and "_wts" values
val arrStructs = pivotCols.map{ c =>
  val structCols = lit(c).as("_pvt") +: valueColSfx.map(s => col(c + s).as(s))
  struct(structCols: _*).as(c + "_struct")
}

// aggregate the "_cnt" and "_wts" struct elements for each pivot column (here via first)
val valueColAgg = valueColSfx.map(s => first($"struct_col.$s").as(s + "_first"))

df.
  select(array(arrStructs: _*).as("arr_structs")).
  withColumn("struct_col", explode($"arr_structs")).
  groupBy($"struct_col._pvt").agg(valueColAgg.head, valueColAgg.tail: _*).
  show
// +----+----------+----------+
// |_pvt|_cnt_first|_wts_first|
// +----+----------+----------+
// |  c1|         3|       0.5|
// |  c3|         5|       0.6|
// |  c2|         4|       0.4|
// +----+----------+----------+

Note that the function first is used in the above example, but any other aggregate function (e.g. avg, max, collect_list) could be used instead, depending on the specific business requirement.
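For example, here is a minimal variant (assuming the same df, arrStructs and valueColSfx defined above) that keeps every value per pivot column via collect_list instead of only the first one; the _cnt_list/_wts_list result names are just illustrative:

// collect_list gathers all "_cnt"/"_wts" values per pivot column instead of keeping the first
val valueColAggList = valueColSfx.map(s => collect_list($"struct_col.$s").as(s + "_list"))

df.
  select(array(arrStructs: _*).as("arr_structs")).
  withColumn("struct_col", explode($"arr_structs")).
  groupBy($"struct_col._pvt").agg(valueColAggList.head, valueColAggList.tail: _*).
  show(false)
// e.g. the row for c1 would hold _cnt_list = [3, 3, 3] and _wts_list = [0.5, 0.5, 0.5]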

Upvotes: 1
