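I have a dataframe as follows:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._ // assumes an existing SparkSession named spark
val initialData = Seq(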
Row("ABC1",List(Row("Java","XX",120),Row("Scala","XA",300))),
Row("Michael",List(Row("Java","XY",200),Row("Scala","XB",500))),
Row("Robert",List(Row("Java","XZ",400),Row("Scala","XC",250)))
)
val arrayStructSchema = new StructType().add("name",StringType)
.add("SortedDataSet",ArrayType(new StructType()
.add("name",StringType)
.add("author",StringType)
.add("pages",IntegerType)))
val df = spark
.createDataFrame(spark.sparkContext.parallelize(initialData),arrayStructSchema)
df.printSchema()
df.show(5, false)
+-------+-----------------------------------+
|name |SortedDataSet |
+-------+-----------------------------------+
|ABC1 |[[Java, XX, 120], [Scala, XA, 300]]|
|Michael|[[Java, XY, 200], [Scala, XB, 500]]|
|Robert |[[Java, XZ, 400], [Scala, XC, 250]]|
+-------+-----------------------------------+
I need to extract each field of every struct in the array as an individual indexed column. Right now, I'm doing the following:
val newDf = df.withColumn("Col1", sort_array('SortedDataSet).getItem(0))
.withColumn("Col2", sort_array('SortedDataSet).getItem(1))
.withColumn("name_1",$"Col1.name")
.withColumn("author_1",$"Col1.author")
.withColumn("pages_1",$"Col1.pages")
.withColumn("name_2",$"Col2.name")
.withColumn("author_2",$"Col2.author")
.withColumn("pages_2",$"Col2.pages")
This is simple because each array has only 2 elements and each struct only 3 fields. What do I do when there are many more array elements and struct fields?
How can I do this programmatically?
One approach would be to flatten the dataframe into indexed array elements using posexplode, followed by a groupBy and pivot on the generated indices, like below:
Given the sample dataset:
df.show(false)
// +-------+--------------------------------------------------+
// |name |SortedDataSet |
// +-------+--------------------------------------------------+
// |ABC1 |[[Java, XX, 120], [Scala, XA, 300]] |
// |Michael|[[Java, XY, 200], [Scala, XB, 500], [Go, XD, 600]]|
// |Robert |[[Java, XZ, 400], [Scala, XC, 250]] |
// +-------+--------------------------------------------------+
Note that I've slightly generalized the sample data to showcase arrays with uneven sizes.
val flattenedDF = df.
select($"name", posexplode($"SortedDataSet")).
groupBy($"name").pivot($"pos" + 1).agg(
first($"col.name").as("name"),
first($"col.author").as("author"),
first($"col.pages").as("pages")
)
flattenedDF.show
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// | name|1_name|1_author|1_pages|2_name|2_author|2_pages|3_name|3_author|3_pages|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// | ABC1| Java| XX| 120| Scala| XA| 300| null| null| null|
// |Michael| Java| XY| 200| Scala| XB| 500| Go| XD| 600|
// | Robert| Java| XZ| 400| Scala| XC| 250| null| null| null|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
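To make the reshaping concrete, here is a minimal plain-Scala model of what posexplode followed by a pivot on pos + 1 does to one row's array, with no Spark involved. Book is a hypothetical case class standing in for the (name, author, pages) struct, and PivotModel is likewise a name made up for this sketch. One difference from Spark: positions missing from a shorter array show up as null columns in the pivoted DataFrame, whereas here they are simply absent keys.

```scala
// Plain-Scala model (no Spark) of posexplode + pivot on pos + 1.
// Book is a hypothetical stand-in for the struct (name, author, pages).
final case class Book(name: String, author: String, pages: Int)

object PivotModel {
  // Like posexplode: one (position, element) pair per element, positions starting at 0
  def posexplode[A](xs: Seq[A]): Seq[(Int, A)] = xs.zipWithIndex.map(_.swap)

  // Like the pivot on pos + 1 with one aggregate per field:
  // every element contributes "<pos+1>_<field>" -> value
  def flattenRow(books: Seq[Book]): Map[String, Any] =
    posexplode(books).flatMap { case (pos, b) =>
      val i = pos + 1
      Seq(s"${i}_name" -> b.name, s"${i}_author" -> b.author, s"${i}_pages" -> b.pages)
    }.toMap
}
```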
To revise the column names created by pivot to the wanted names:
val pattern = "^\\d+_.*"
val flattenedCols = flattenedDF.columns.filter(_ matches pattern)
def colRenamed(c: String): String =
c.split("_", 2).reverse.mkString("_") // Split on first "_" and switch segments
flattenedDF.
select($"name" +: flattenedCols.map(c => col(c).as(colRenamed(c))): _*).
show
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// | name|name_1|author_1|pages_1|name_2|author_2|pages_2|name_3|author_3|pages_3|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// | ABC1| Java| XX| 120| Scala| XA| 300| null| null| null|
// |Michael| Java| XY| 200| Scala| XB| 500| Go| XD| 600|
// | Robert| Java| XZ| 400| Scala| XC| 250| null| null| null|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
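The rename rule can be checked in isolation, without a SparkSession. This is a small sketch; PivotRename and renameAll are hypothetical names, and "page_count" is a made-up field name used only to show why the split is limited to 2 parts: only the first "_" separates the index from the field, so underscores inside a field name survive.

```scala
// Standalone version of the rename rule: "<index>_<field>" -> "<field>_<index>".
object PivotRename {
  private val pattern = "^\\d+_.*"

  // Split on the first "_" only, then swap the two segments
  def colRenamed(c: String): String = c.split("_", 2).reverse.mkString("_")

  // Rename only pivot-generated columns; others (e.g. "name") pass through unchanged
  def renameAll(cols: Seq[String]): Seq[String] =
    cols.map(c => if (c.matches(pattern)) colRenamed(c) else c)
}
```

With Spark available, the same rule could rename the pivoted frame positionally in one call, e.g. `flattenedDF.toDF(PivotRename.renameAll(flattenedDF.columns.toSeq): _*)`.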
If your arrays have the same size, you can avoid doing an expensive explode, group by and pivot, by selecting the array and struct elements dynamically:
val arrSize = df.select(size(col("SortedDataSet"))).first().getInt(0)
val df2 = (1 to arrSize).foldLeft(df)(
(d, i) =>
d.withColumn(
s"Col$i",
sort_array(col("SortedDataSet"))(i-1)
)
)
val colNames = df.selectExpr("SortedDataSet[0] as tmp").select("tmp.*").columns
// colNames: Array[String] = Array(name, author, pages)
val colList = (1 to arrSize).map("Col" + _ + ".*").toSeq
// colList: scala.collection.immutable.Seq[String] = Vector(Col1.*, Col2.*)
val colRename = df2.columns ++ (
for {x <- (1 to arrSize); y <- colNames}
yield (x,y)
).map(
x => x._2 + "_" + x._1
).toArray[String]
// colRename: Array[String] = Array(name, SortedDataSet, Col1, Col2, name_1, author_1, pages_1, name_2, author_2, pages_2)
val newDf = df2.select("*", colList: _*).toDF(colRename: _*)
newDf.show(false)
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
|name |SortedDataSet |Col1 |Col2 |name_1|author_1|pages_1|name_2|author_2|pages_2|
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
|ABC1 |[[Java, XX, 120], [Scala, XA, 300]]|[Java, XX, 120]|[Scala, XA, 300]|Java |XX |120 |Scala |XA |300 |
|Michael|[[Java, XY, 200], [Scala, XB, 500]]|[Java, XY, 200]|[Scala, XB, 500]|Java |XY |200 |Scala |XB |500 |
|Robert |[[Java, XZ, 400], [Scala, XC, 250]]|[Java, XZ, 400]|[Scala, XC, 250]|Java |XZ |400 |Scala |XC |250 |
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
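A variation on the same idea skips the intermediate Col1, Col2, ... columns and the positional toDF rename by generating one Spark SQL expression per (index, field) pair and passing them all to selectExpr. This is a sketch under the same assumption as the answer (every array has at least n elements); FlattenExprs and build are hypothetical names, and the function itself only builds strings, so it runs without Spark.

```scala
// Builds Spark SQL expressions that pull field f of element i out of the sorted array,
// aliased as "<f>_<i>". Pure string generation; no SparkSession needed to call it.
object FlattenExprs {
  // Backtick-quote identifiers so unusual column or field names still parse
  private def q(id: String): String = "`" + id.replace("`", "``") + "`"

  // n = number of array elements to extract; fields = struct field names
  def build(arrayCol: String, fields: Seq[String], n: Int): Seq[String] =
    for {
      i <- 1 to n
      f <- fields
    } yield s"sort_array(${q(arrayCol)})[${i - 1}].${q(f)} AS ${q(s"${f}_$i")}"
}
```

Reusing arrSize and colNames from the answer, usage would look like `df.selectExpr(("name" +: FlattenExprs.build("SortedDataSet", colNames.toSeq, arrSize)): _*)`, which yields only name plus the flattened columns.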