Peh Qin Cheng

Reputation: 73

How to flatten an Array of WrappedArray of structs in Scala

I have a dataframe with the following schema:

root
 |-- id: string (nullable = true)
 |-- collect_list(typeCounts): array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- type: string (nullable = true)
 |    |    |    |-- count: long (nullable = false)

Example data:

+-----------+----------------------------------------------------------------------------+
|id         |collect_list(typeCounts)                                                    |
+-----------+----------------------------------------------------------------------------+
|1          |[WrappedArray([B00XGS,6], [B001FY,5]), WrappedArray([B06LJ7,4])]|
|2          |[WrappedArray([B00UFY,3])]                                              |
+-----------+----------------------------------------------------------------------------+

How can I flatten collect_list(typeCounts) into a flat array of structs in Scala? I have read some answers on Stack Overflow to similar questions that suggest UDFs, but I am not sure what the UDF method signature should be for structs.

Upvotes: 0

Views: 602

Answers (1)

Leo C

Reputation: 22439

If you're on Spark 2.4+, you can apply the built-in flatten function instead of a UDF (UDFs are generally less efficient than native Spark functions):

import org.apache.spark.sql.functions.flatten
df.withColumn("collect_list(typeCounts)", flatten($"collect_list(typeCounts)"))
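
For intuition, flatten on an array-of-arrays column concatenates the inner arrays in order, much like flatten on nested Scala collections. Below is a minimal sketch of that analogy on plain Scala collections, with no Spark involved. The TypeCount case class and the sample values are illustrative stand-ins for the struct and the example data in the question, not part of any Spark API.

```scala
// Plain-Scala analogy only: TypeCount is a hypothetical stand-in for the struct element.
case class TypeCount(`type`: String, count: Long)

// Mirrors the row with id 1 from the example data: an array of arrays of structs.
val nested: Seq[Seq[TypeCount]] = Seq(
  Seq(TypeCount("B00XGS", 6), TypeCount("B001FY", 5)),
  Seq(TypeCount("B06LJ7", 4))
)

// Removes one level of nesting, keeping the elements in their original order.
val flat: Seq[TypeCount] = nested.flatten
```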

Regarding "I am not sure what the UDF method signature should be for structs":

A UDF receives struct values as Row objects and can return them as Scala case classes, which Spark converts back to structs. To flatten the nested collections, you can create a simple UDF as follows:

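import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf

case class TC(`type`: String, count: Long)

// Each struct arrives as a Row; returning a case class yields a struct column.
val flattenLists = udf{ (lists: Seq[Seq[Row]]) =>
  lists.flatMap( _.map{ case Row(t: String, c: Long) => TC(t, c) } )
}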

To test out the UDF, let's assemble a DataFrame with your described schema:

val df = Seq(
    ("1", Seq(TC("B00XGS", 6), TC("B001FY", 5))),
    ("1", Seq(TC("B06LJ7", 4))),
    ("2", Seq(TC("B00UFY", 3)))
  ).toDF("id", "typeCounts").
  groupBy("id").agg(collect_list("typeCounts"))

df.printSchema
// root
//  |-- id: string (nullable = true)
//  |-- collect_list(typeCounts): array (nullable = true)
//  |    |-- element: array (containsNull = true)
//  |    |    |-- element: struct (containsNull = true)
//  |    |    |    |-- type: string (nullable = true)
//  |    |    |    |-- count: long (nullable = false)

Applying the UDF:

df.
  withColumn("collect_list(typeCounts)", flattenLists($"collect_list(typeCounts)")).
  printSchema
// root
//  |-- id: string (nullable = true)
//  |-- collect_list(typeCounts): array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- type: string (nullable = true)
//  |    |    |-- count: long (nullable = false)
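
If the nested column only exists because of the collect_list aggregation, the flattening could also be folded into the aggregation itself, which avoids the awkward collect_list(typeCounts) column name. The following is a sketch under a few assumptions: Spark 2.4+, a SparkSession named spark (as in spark-shell), and an output column name typeCounts chosen arbitrarily. TC is redefined here only so the snippet stands alone.

```scala
import org.apache.spark.sql.functions.{collect_list, flatten}
import spark.implicits._  // assumes a SparkSession named `spark`, as in spark-shell

case class TC(`type`: String, count: Long)

// Same sample rows as the test DataFrame above.
val flatDf = Seq(
    ("1", Seq(TC("B00XGS", 6), TC("B001FY", 5))),
    ("1", Seq(TC("B06LJ7", 4))),
    ("2", Seq(TC("B00UFY", 3)))
  ).toDF("id", "typeCounts").
  groupBy("id").
  // Collect the per-row arrays, flatten them, and give the result a plain name.
  agg(flatten(collect_list("typeCounts")).as("typeCounts"))
```

Calling flatDf.printSchema should then show typeCounts as an array of structs directly, with no intermediate array-of-arrays column.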

Upvotes: 1
