How can I efficiently aggregate a column into a Set (an array of unique elements) in Spark?
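// Assumes a SparkSession named `spark` (as in spark-shell); the implicits are needed for toDS().
import spark.implicits._
case class Foo(a: String, b: String, c: Int, d: Array[String])
val df = Seq(
  Foo("A", "A", 123, Array("A")),
  Foo("A", "A", 123, Array("B")),
  Foo("B", "B", 123, Array("C", "A")),
  Foo("B", "B", 123, Array("C", "E", "A")),
  Foo("B", "B", 123, Array("D"))
).toDS()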
Calling df.show() on this Dataset prints:
+---+---+---+---------+
|  a|  b|  c|        d|
+---+---+---+---------+
|  A|  A|123|      [A]|
|  A|  A|123|      [B]|
|  B|  B|123|   [C, A]|
|  B|  B|123|[C, E, A]|
|  B|  B|123|      [D]|
+---+---+---+---------+
What I am looking for is the following (the order of the elements in column d does not matter):
+---+---+---+------------+
|  a|  b|  c|           d|
+---+---+---+------------+
|  A|  A|123|      [A, B]|
|  B|  B|123|[C, A, E, D]|
+---+---+---+------------+
This may be similar to "How to aggregate values into collection after groupBy?" or to the example from High Performance Spark at https://github.com/high-performance-spark/high-performance-spark-examples/blob/57a6267fb77fae5a90109bfd034ae9c18d2edf22/src/main/scala/com/high-performance-spark-examples/transformations/SmartAggregations.scala#L33-L43
Using the following code:
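// collect_list was missing from the original import; $"..." comes from spark.implicits._
import org.apache.spark.sql.functions.{collect_list, udf}
// UDF that concatenates the collected arrays and removes duplicate elements
val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten.distinct)
val d = flatten(collect_list($"d")).alias("d")
df.groupBy($"a", $"b", $"c").agg(d).show()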
It produces the desired result, but I wonder whether performance could be improved with the RDD API, as outlined in the book. I would also like to know how to express the same aggregation with the Dataset API.
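One RDD formulation is aggregateByKey with a mutable set as the per-key accumulator, so values are deduplicated inside each partition before the shuffle instead of shipping every row as groupByKey would. This is a minimal sketch, not a benchmarked answer: the object name DistinctUnionByKey and the method viaRdd are hypothetical, and it assumes Spark 2.x or 3.x with Scala 2.12/2.13. The case class is repeated only to keep the snippet self-contained; drop it if Foo is already defined in your session.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import scala.collection.mutable

// Same shape as the case class in the question.
case class Foo(a: String, b: String, c: Int, d: Array[String])

// Hypothetical helper name; a sketch under the assumptions stated above.
object DistinctUnionByKey {
  def viaRdd(ds: Dataset[Foo]): RDD[Foo] =
    ds.rdd
      // Key each row by the grouping columns.
      .map(foo => ((foo.a, foo.b, foo.c), foo.d))
      // Build one set per key. seqOp runs inside each partition before the
      // shuffle, so at most one deduplicated set per key and partition is sent.
      // Spark gives each key its own copy of the zero value, so mutating the
      // set in place is allowed.
      .aggregateByKey(mutable.HashSet.empty[String])(
        (set, arr) => set ++= arr, // add one row's elements to the partition-local set
        (s1, s2) => s1 ++= s2      // merge sets coming from different partitions
      )
      // Rebuild the original row type from the key and the merged set.
      .map { case ((a, b, c), set) => Foo(a, b, c, set.toArray) }
}
```

With import spark.implicits._ in scope, DistinctUnionByKey.viaRdd(df).toDS().show() turns the result back into a Dataset. Whether this beats the UDF version depends on how many duplicates each partition contains, so it is worth timing on real data.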
The query plans for this minimal sample are:
== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#45, b#46, c#47], [a#45, b#46, c#47, UDF(collect_list(d#48, 0, 0)) AS d#82]
      +- LocalRelation [a#45, b#46, c#47, d#48]

== Physical Plan ==
CollectLimit 21
+- SortAggregate(key=[a#45, b#46, c#47], functions=[collect_list(d#48, 0, 0)], output=[a#45, b#46, c#47, d#82])
   +- *Sort [a#45 ASC NULLS FIRST, b#46 ASC NULLS FIRST, c#47 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(a#45, b#46, c#47, 200)
         +- LocalTableScan [a#45, b#46, c#47, d#48]
The problems with this operation are explained well at https://github.com/awesome-spark/spark-gotchas/blob/master/04_rdd_actions_and_transformations_by_example.md#be-smart-about-groupbykey
The DAG for the Dataset query suggested in the answer below is more complicated, and it takes about 2 seconds instead of 0.4 seconds.
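For comparison, the same result can be written with built-in SQL functions only. This keeps the whole aggregation inside Catalyst expressions: there is no UDF call, and no typed lambda that forces rows to be deserialized into Foo. It is a sketch under stated assumptions, and I have not timed it against the figures above. The object and method names are hypothetical. viaFlatten assumes Spark 2.4 or later, where flatten and array_distinct were added. viaExplode is a fallback for older versions.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{array_distinct, col, collect_list, collect_set, explode, flatten}

// Hypothetical helper names; both variants expect columns a, b, c and an array column d.
object DistinctUnionBuiltins {
  // Spark 2.4+: concatenate the collected arrays, then drop duplicate elements.
  def viaFlatten(df: DataFrame): DataFrame =
    df.groupBy(col("a"), col("b"), col("c"))
      .agg(array_distinct(flatten(collect_list(col("d")))).alias("d"))

  // Before 2.4: one row per array element, then collect the distinct values.
  // Caveat: explode drops rows whose d is empty or null (explode_outer keeps them).
  def viaExplode(df: DataFrame): DataFrame =
    df.withColumn("d", explode(col("d")))
      .groupBy(col("a"), col("b"), col("c"))
      .agg(collect_set(col("d")).alias("d"))
}
```

Because the question's df is a Dataset[Foo], call it as DistinctUnionBuiltins.viaFlatten(df.toDF()).show().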
Answer:
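Try this:
df.groupByKey(foo => (foo.a, foo.b, foo.c)).
  reduceGroups { (foo1, foo2) =>
    // merge two rows with the same key, dropping duplicate elements of d
    foo1.copy(d = (foo1.d ++ foo2.d).distinct)
  }.map(_._2) // keep only the merged Foo, discarding the grouping key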