Reputation: 67
I have a DataFrame that can have multiple columns of Array type, such as "Array1", "Array2", etc. These array columns all have the same number of elements. I need to compute a new Array-type column that is the element-wise sum of these arrays. How can I do it?
Spark version = 2.3
For example:
Input:
|Column1| ... |ArrayColumn2|ArrayColumn3|
|-------| --- |------------|------------|
|T1     | ... |[1, 2, 3]   |[2, 5, 7]   |
Output:
|Column1| ... |AggregatedColumn|
|-------| --- |----------------|
|T1     | ... |[3, 7, 10]      |
The number of array columns is not fixed, so I need a generalized solution. I will have a list of the columns to aggregate.
Thanks!
Upvotes: 3
Views: 1032
Reputation: 14845
An SQL expression like array(ArrayColumn2[0]+ArrayColumn3[0], ArrayColumn2[1]+...) can be used to calculate the expected result.
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.ArrayType
val df = ...
//get the names of all array columns
val arrayCols = df.schema.fields.filter(_.dataType.isInstanceOf[ArrayType]).map(_.name)
//get the size of the first array of the first row (all arrays are assumed to have this size)
val firstArray = arrayCols(0)
val arraySize = df.selectExpr(s"size($firstArray)").first().getAs[Int](0)
//generate the sql expression for the sums: one "+"-joined term per array position
val sums = (for (i <- 0 until arraySize)
  yield arrayCols.map(c => s"$c[$i]").mkString("+")).mkString(",")
//sums = ArrayColumn2[0]+ArrayColumn3[0],ArrayColumn2[1]+ArrayColumn3[1],ArrayColumn2[2]+ArrayColumn3[2]
//create a new column using sums
df.withColumn("AggregatedColumn", expr(s"array($sums)")).show()
Output:
+-------+------------+------------+----------------+
|Column1|ArrayColumn2|ArrayColumn3|AggregatedColumn|
+-------+------------+------------+----------------+
| T1| [1, 2, 3]| [2, 5, 7]| [3, 7, 10]|
+-------+------------+------------+----------------+
Because this single (long) SQL expression is evaluated row by row, it avoids shuffling data over the network and thus improves performance.
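To inspect the generated expression without a Spark session, the string-building step can be isolated into a small helper. This is only a minimal sketch: `buildSumExpr` and `SumExprSketch` are hypothetical names, not part of Spark, and the helper assumes every array column has exactly `arraySize` elements.

```scala
object SumExprSketch {
  // Hypothetical helper: builds the SQL text for an element-wise sum.
  // Assumes every listed column is an array with exactly arraySize elements.
  def buildSumExpr(arrayCols: Seq[String], arraySize: Int): String = {
    val sums = (0 until arraySize)
      .map(i => arrayCols.map(c => s"$c[$i]").mkString("+"))
      .mkString(",")
    s"array($sums)"
  }

  def main(args: Array[String]): Unit = {
    println(buildSumExpr(Seq("ArrayColumn2", "ArrayColumn3"), 3))
    // prints: array(ArrayColumn2[0]+ArrayColumn3[0],ArrayColumn2[1]+ArrayColumn3[1],ArrayColumn2[2]+ArrayColumn3[2])
  }
}
```

The returned string can be passed to `expr(...)` in the same way as in the answer. Since the size is read from the first row only, it is worth checking that all rows really hold arrays of that length before relying on the result.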
Upvotes: 1
Reputation: 22439
Consider using inline and the higher-order function aggregate (available in Spark 2.4+) to compute element-wise sums from the Array-typed columns, followed by a groupBy/agg to group the element-wise sums back into Arrays:
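import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession named spark, as in spark-shell
val df = Seq(
(101, Seq(1, 2), Seq(3, 4), Seq(5, 6)),
(202, Seq(7, 8), Seq(9, 10), Seq(11, 12))
).toDF("id", "arr1", "arr2", "arr3")
// all columns whose name starts with "arr", as Column objects
val arrCols = df.columns.filter(_.startsWith("arr")).map(col)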
For Spark 3.0+
df.
withColumn("arr_structs", arrays_zip(arrCols: _*)).
select($"id", expr("inline(arr_structs)")).
select($"id", aggregate(array(arrCols: _*), lit(0), (acc, x) => acc + x).as("pos_elem_sum")).
groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
show
// +---+------------+
// | id|arr_elem_sum|
// +---+------------+
// |101| [9, 12]|
// |202| [27, 30]|
// +---+------------+
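The stages of this pipeline can be mirrored with plain Scala collections, which may make it easier to see what each step contributes. This is an analogy only, not Spark code: `PipelineSketch` is a hypothetical name, and the data is the row with id 101 from the example above.

```scala
object PipelineSketch {
  // One input row: the three array columns of id 101.
  val row: Seq[Seq[Int]] = Seq(Seq(1, 2), Seq(3, 4), Seq(5, 6))

  // Counterpart of arrays_zip + inline: one record per array position,
  // here (1, 3, 5) and (2, 4, 6).
  val perPosition: Seq[Seq[Int]] = row.transpose

  // Counterpart of aggregate(array(...), 0, (acc, x) => acc + x):
  // fold each per-position record into its sum.
  val posElemSums: Seq[Int] = perPosition.map(_.foldLeft(0)(_ + _))

  def main(args: Array[String]): Unit = {
    println(posElemSums) // prints: List(9, 12)
  }
}
```

A local collection keeps its order, whereas `collect_list` after a `groupBy` is generally not guaranteed to return elements in their original positions, so the order of the Spark result is worth verifying on real data.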
For Spark 2.4+
df.
withColumn("arr_structs", arrays_zip(arrCols: _*)).
select($"id", expr("inline(arr_structs)")).
select($"id", array(arrCols: _*).as("arr_pos_elems")).
select($"id", expr("aggregate(arr_pos_elems, 0, (acc, x) -> acc + x)").as("pos_elem_sum")).
groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
show
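As an alternative sketch for Spark 2.4+, not taken from the answer itself: the SQL higher-order function `zip_with` adds two arrays position by position, so nesting it across all array columns should give the element-wise sum in a single row-level expression, without `inline` or `groupBy`. The helper below only builds the expression string; `buildZipWithExpr` and `ZipWithExprSketch` are hypothetical names.

```scala
object ZipWithExprSketch {
  // Hypothetical helper: nests zip_with calls so that each array column is
  // added, position by position, to the running result.
  def buildZipWithExpr(arrayCols: Seq[String]): String = {
    require(arrayCols.nonEmpty, "at least one array column is needed")
    arrayCols.reduceLeft((acc, c) => s"zip_with($acc, $c, (x, y) -> x + y)")
  }

  def main(args: Array[String]): Unit = {
    println(buildZipWithExpr(Seq("arr1", "arr2", "arr3")))
    // prints: zip_with(zip_with(arr1, arr2, (x, y) -> x + y), arr3, (x, y) -> x + y)
  }
}
```

Under these assumptions the string could be used as `df.withColumn("arr_elem_sum", expr(buildZipWithExpr(Seq("arr1", "arr2", "arr3"))))`, which keeps each row intact and so sidesteps any regrouping step.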
For Spark 2.3 or below
// arrays_zip was only added in Spark 2.4, so this variant sums the arrays inside a UDF instead
// the UDF receives all array columns as one array of arrays and sums them position by position
val sumArrElems = udf{ (arrs: Seq[Seq[Int]]) => arrs.transpose.map(_.sum) }
df.
select($"id", sumArrElems(array(arrCols: _*)).as("arr_elem_sum")).
show
Upvotes: 2