Georg Heiler

Reputation: 17676

Spark: aggregating a column into a Set efficiently

How can I efficiently aggregate a column into a Set (an array of unique elements) in Spark?

import spark.implicits._

case class Foo(a: String, b: String, c: Int, d: Array[String])

val df = Seq(
  Foo("A", "A", 123, Array("A")),
  Foo("A", "A", 123, Array("B")),
  Foo("B", "B", 123, Array("C", "A")),
  Foo("B", "B", 123, Array("C", "E", "A")),
  Foo("B", "B", 123, Array("D"))
).toDS()

This will result in:

+---+---+---+---------+
|  a|  b|  c|        d|
+---+---+---+---------+
|  A|  A|123|      [A]|
|  A|  A|123|      [B]|
|  B|  B|123|   [C, A]|
|  B|  B|123|[C, E, A]|
|  B|  B|123|      [D]|
+---+---+---+---------+

What I am looking for is (the ordering of the d column is not important):

+---+---+---+------------+
|  a|  b|  c|           d|
+---+---+---+------------+
|  A|  A|123|      [A, B]|
|  B|  B|123|[C, A, E, D]|
+---+---+---+------------+

This may be a bit similar to How to aggregate values into collection after groupBy? or to the example from High Performance Spark at https://github.com/high-performance-spark/high-performance-spark-examples/blob/57a6267fb77fae5a90109bfd034ae9c18d2edf22/src/main/scala/com/high-performance-spark-examples/transformations/SmartAggregations.scala#L33-L43

Using the following code:

import org.apache.spark.sql.functions.{collect_list, udf}

// flatten the collected arrays and keep only the unique elements
val flatten = udf((xs: Seq[Seq[String]]) => xs.flatten.distinct)
val d = flatten(collect_list($"d")).alias("d")
df.groupBy($"a", $"b", $"c").agg(d).show

will produce the desired result, but I wonder whether there is any possibility to improve performance using the RDD API, as outlined in the book. I would also like to know how to formulate it using the Dataset API.
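For reference, here is a minimal sketch of one possible RDD formulation (my own assumption of the approach, not code taken from the book); aggregateByKey builds a per-partition mutable set, so duplicates are dropped map-side before the shuffle:

import scala.collection.mutable

val aggregatedRdd = df.rdd
  .map(foo => ((foo.a, foo.b, foo.c), foo.d))
  .aggregateByKey(mutable.Set.empty[String])(
    (set, xs) => set ++= xs,        // merge one row's array into the partition-local set
    (set1, set2) => set1 ++= set2   // merge sets coming from different partitions
  )
  .map { case ((a, b, c), set) => Foo(a, b, c, set.toArray) }
  // .toDS() would bring it back to a Dataset if needed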

Details about the execution of this minimal sample follow below:

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#45, b#46, c#47], [a#45, b#46, c#47, UDF(collect_list(d#48, 0, 0)) AS d#82]
      +- LocalRelation [a#45, b#46, c#47, d#48]

== Physical Plan ==
CollectLimit 21
+- SortAggregate(key=[a#45, b#46, c#47], functions=[collect_list(d#48, 0, 0)], output=[a#45, b#46, c#47, d#82])
   +- *Sort [a#45 ASC NULLS FIRST, b#46 ASC NULLS FIRST, c#47 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(a#45, b#46, c#47, 200)
         +- LocalTableScan [a#45, b#46, c#47, d#48]

[Screenshots: SQL DAG and stage DAG]

edit

The problems with this operation are outlined very well at https://github.com/awesome-spark/spark-gotchas/blob/master/04_rdd_actions_and_transformations_by_example.md#be-smart-about-groupbykey

edit2

As you can see, the DAG for the Dataset query suggested below is more complicated, and instead of 0.4 seconds it seems to take 2 seconds.

[Screenshot: DAG for answer 1]

Upvotes: 3

Views: 1664

Answers (1)

Mikel San Vicente

Reputation: 3863

Try this

df.groupByKey(foo => (foo.a, foo.b, foo.c))
  .reduceGroups { (foo1, foo2) =>
    // merge the two rows, keeping a single de-duplicated d array
    foo1.copy(d = (foo1.d ++ foo2.d).distinct)
  }
  .map(_._2)
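A quick usage sketch (my own addition, assuming the df from the question is in scope); the trailing map(_._2) is needed because reduceGroups returns (key, value) pairs:

val result = df
  .groupByKey(foo => (foo.a, foo.b, foo.c))
  .reduceGroups((foo1, foo2) => foo1.copy(d = (foo1.d ++ foo2.d).distinct))
  .map(_._2)

result.show()   // one row per (a, b, c) key, with a de-duplicated d column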

Upvotes: 1
