Reputation: 383
+---+-----+---+
|M  |Index|c1 |
+---+-----+---+
|M1 |0    |224|
|M1 |1    |748|
|M1 |3    |56 |
+---+-----+---+
I have a DataFrame like the one above. If I pivot it with
df.groupBy("M").pivot("Index").agg(first("c1"))
I get the result below, but index '2' is missing from the series. This may be silly but tricky: is there any way to fill up the column series while doing the pivot?
+---+---+---+---+
|M  |0  |1  |3  |
+---+---+---+---+
|M1 |224|748|56 |
+---+---+---+---+
Expected result:
+---+---+---+---+---+
|M  |0  |1  |2  |3  |
+---+---+---+---+---+
|M1 |224|748|0  |56 |
+---+---+---+---+---+
Upvotes: 1
Views: 343
Reputation: 2518
Welcome to SO @abc_spark,
Supposing you don't have too many indexes in your table, you can try the following approach: first compute the maximum index value across the whole DataFrame, then for each index from 0 to that maximum, add a column with a default value of 0 if it is not already present. Note that I'm also filling the null values produced by the pivot with zeros.
import spark.implicits._
import org.apache.spark.sql.functions._

val df = Seq(
  ("M1", 0, 224),
  ("M1", 1, 748),
  ("M1", 3, 56),
  ("M2", 3, 213)
).toDF("M", "Index", "c1")

// Pivot as usual, replacing the nulls the pivot produces with 0
val pivoted = df.groupBy("M").pivot("Index").agg(first("c1")).na.fill(0)

// Highest index present in the data
val maxValue = df.select(max($"Index")).collect.head.getAs[Int](0)

// Add a zero-filled column for every index missing from the pivot result
val withAllCols = (0 to maxValue).foldLeft(pivoted) { case (acc, idx) =>
  if (acc.columns contains idx.toString) acc
  else acc.withColumn(idx.toString, lit(0))
}

withAllCols.show(false)
+---+---+---+---+---+
|M  |0  |1  |3  |2  |
+---+---+---+---+---+
|M2 |0  |0  |213|0  |
|M1 |224|748|56 |0  |
+---+---+---+---+---+
Edit: since withColumn appends the missing columns at the end, sort the columns explicitly to get them back in order:
withAllCols
.select("M", withAllCols.columns.filterNot(_ == "M").sortBy(_.toInt):_*)
.show(false)
+---+---+---+---+---+
|M  |0  |1  |2  |3  |
+---+---+---+---+---+
|M2 |0  |0  |0  |213|
|M1 |224|748|0  |56 |
+---+---+---+---+---+
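For completeness, a minimal alternative sketch (not from the original answer, assuming Spark 2.0+): pivot also accepts an explicit list of values, so you can hand it the full 0-to-maxValue range up front and get every column, already in order, in a single pass.

// Giving pivot the full list of expected index values creates a column for
// every value, even those absent from the data; na.fill(0) then replaces
// the nulls. Columns come out in the order of the supplied list, so no
// extra sorting step is needed. Reuses df and maxValue from above.
val pivotedAll = df
  .groupBy("M")
  .pivot("Index", 0 to maxValue)
  .agg(first("c1"))
  .na.fill(0)

pivotedAll.show(false)

As a bonus, listing the pivot values explicitly also skips the extra job Spark would otherwise run to discover the distinct values of the pivot column.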
Upvotes: 2