Ajay

Reputation: 198

Programmatically extract columns from Struct column as individual columns

I have a dataframe as follows:

val initialData = Seq(
    Row("ABC1",List(Row("Java","XX",120),Row("Scala","XA",300))),
    Row("Michael",List(Row("Java","XY",200),Row("Scala","XB",500))),
    Row("Robert",List(Row("Java","XZ",400),Row("Scala","XC",250)))
)
    
val arrayStructSchema = new StructType()
  .add("name", StringType)
  .add("SortedDataSet", ArrayType(new StructType()
    .add("name", StringType)
    .add("author", StringType)
    .add("pages", IntegerType)))

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(initialData), arrayStructSchema)

df.printSchema()
df.show(5, false)

+-------+-----------------------------------+
|name   |SortedDataSet                      |
+-------+-----------------------------------+
|ABC1   |[[Java, XX, 120], [Scala, XA, 300]]|
|Michael|[[Java, XY, 200], [Scala, XB, 500]]|
|Robert |[[Java, XZ, 400], [Scala, XC, 250]]|
+-------+-----------------------------------+

I need to extract each element of the struct as an individual indexed column. Right now, I'm doing the following:

val newDf = df.withColumn("Col1", sort_array('SortedDataSet).getItem(0))
.withColumn("Col2", sort_array('SortedDataSet).getItem(1))
.withColumn("name_1",$"Col1.name")
.withColumn("author_1",$"Col1.author")
.withColumn("pages_1",$"Col1.pages")
.withColumn("name_2",$"Col2.name")
.withColumn("author_2",$"Col2.author")
.withColumn("pages_2",$"Col2.pages")

This is simple while I have only 2 array elements and a handful of columns. What do I do when I have many more array elements and columns?

How can I do this programmatically?

Upvotes: 1

Views: 469

Answers (2)

Leo C

Reputation: 22439

One approach would be to flatten the dataframe to generate indexed array elements using posexplode, followed by a groupBy and pivot on the generated indices, like below:

Given the sample dataset:

df.show(false)
// +-------+--------------------------------------------------+
// |name   |SortedDataSet                                     |
// +-------+--------------------------------------------------+
// |ABC1   |[[Java, XX, 120], [Scala, XA, 300]]               |
// |Michael|[[Java, XY, 200], [Scala, XB, 500], [Go, XD, 600]]|
// |Robert |[[Java, XZ, 400], [Scala, XC, 250]]               |
// +-------+--------------------------------------------------+

Note that I've slightly generalized the sample data to showcase arrays with uneven sizes.

val flattenedDF = df.
  select($"name", posexplode($"SortedDataSet")).
  groupBy($"name").pivot($"pos" + 1).agg(
    first($"col.name").as("name"),
    first($"col.author").as("author"),
    first($"col.pages").as("pages")
  )
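For reference, the intermediate posexplode step (before the groupBy/pivot) emits one row per array element, using Spark's default generator column names pos (the 0-based element index) and col (the struct itself) -- an abridged sketch of its output:

```scala
df.select($"name", posexplode($"SortedDataSet")).show(false)
// +-------+---+----------------+
// |name   |pos|col             |
// +-------+---+----------------+
// |ABC1   |0  |[Java, XX, 120] |
// |ABC1   |1  |[Scala, XA, 300]|
// |Michael|0  |[Java, XY, 200] |
// ...
```

Pivoting on pos + 1 then turns each of these indices into its own group of columns, with first() picking the single struct field value per (name, index) cell.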

flattenedDF.show
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// |   name|1_name|1_author|1_pages|2_name|2_author|2_pages|3_name|3_author|3_pages|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// |   ABC1|  Java|      XX|    120| Scala|      XA|    300|  null|    null|   null|
// |Michael|  Java|      XY|    200| Scala|      XB|    500|    Go|      XD|    600|
// | Robert|  Java|      XZ|    400| Scala|      XC|    250|  null|    null|   null|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+

To rename the columns generated by pivot into the wanted field_index format:

val pattern = "^\\d+_.*"
val flattenedCols = flattenedDF.columns.filter(_ matches pattern)

def colRenamed(c: String): String =
  c.split("_", 2).reverse.mkString("_")  // Split on first "_" and switch segments
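For instance, colRenamed turns the index prefix produced by pivot into a suffix:

```scala
colRenamed("1_name")   // "name_1"
colRenamed("3_pages")  // "pages_3"
```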

flattenedDF.
  select($"name" +: flattenedCols.map(c => col(c).as(colRenamed(c))): _*).
  show
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// |   name|name_1|author_1|pages_1|name_2|author_2|pages_2|name_3|author_3|pages_3|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+
// |   ABC1|  Java|      XX|    120| Scala|      XA|    300|  null|    null|   null|
// |Michael|  Java|      XY|    200| Scala|      XB|    500|    Go|      XD|    600|
// | Robert|  Java|      XZ|    400| Scala|      XC|    250|  null|    null|   null|
// +-------+------+--------+-------+------+--------+-------+------+--------+-------+

Upvotes: 1

mck

Reputation: 42342

If your arrays have the same size, you can avoid doing an expensive explode, group by and pivot, by selecting the array and struct elements dynamically:

val arrSize = df.select(size(col("SortedDataSet"))).first().getInt(0)
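Note that this reads the array size from the first row only, which is fine under the same-size assumption stated above. If sizes could vary after all, one sketch is to take the global maximum instead (max and size both from org.apache.spark.sql.functions), at the cost of a full scan:

```scala
// Hypothetical variant: widest array across all rows
val maxArrSize = df.select(max(size(col("SortedDataSet")))).first().getInt(0)
```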

// Add one struct column per array index: Col1, Col2, ...
val df2 = (1 to arrSize).foldLeft(df) { (d, i) =>
  d.withColumn(s"Col$i", sort_array(col("SortedDataSet"))(i - 1))
}

val colNames = df.selectExpr("SortedDataSet[0] as tmp").select("tmp.*").columns
// colNames: Array[String] = Array(name, author, pages)

val colList = (1 to arrSize).map("Col" + _ + ".*").toSeq
// colList: scala.collection.immutable.Seq[String] = Vector(Col1.*, Col2.*)

val colRename = df2.columns ++ (
    for {x <- (1 to arrSize); y <- colNames} 
    yield (x,y)
).map(
    x => x._2 + "_" + x._1
).toArray[String]
// colRename: Array[String] = Array(name, SortedDataSet, Col1, Col2, name_1, author_1, pages_1, name_2, author_2, pages_2)

val newDf = df2.select("*", colList: _*).toDF(colRename: _*)

newDf.show(false)
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
|name   |SortedDataSet                      |Col1           |Col2            |name_1|author_1|pages_1|name_2|author_2|pages_2|
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
|ABC1   |[[Java, XX, 120], [Scala, XA, 300]]|[Java, XX, 120]|[Scala, XA, 300]|Java  |XX      |120    |Scala |XA      |300    |
|Michael|[[Java, XY, 200], [Scala, XB, 500]]|[Java, XY, 200]|[Scala, XB, 500]|Java  |XY      |200    |Scala |XB      |500    |
|Robert |[[Java, XZ, 400], [Scala, XC, 250]]|[Java, XZ, 400]|[Scala, XC, 250]|Java  |XZ      |400    |Scala |XC      |250    |
+-------+-----------------------------------+---------------+----------------+------+--------+-------+------+--------+-------+
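If the intermediate Col1 and Col2 struct columns are only needed as a stepping stone, they can be dropped afterwards; drop accepts multiple column names:

```scala
val trimmedDf = newDf.drop((1 to arrSize).map("Col" + _): _*)
```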

Upvotes: 1
