Reputation: 33
I'm processing a table like this:
ID f1
001 1
001 2
001 3
002 0
002 7
and I want to calculate the sum of the f1 column for each ID, and add that sum as a new column, that is:
ID f1 sum_f1
001 1 6
001 2 6
001 3 6
002 0 7
002 7 7
My solution is to calculate the sum with reduceByKey and then join the result with the original table:
val table = sc.parallelize(Seq(("001",1),("001",2),("001",3),("002",0),("002",7)))
val sum = table.reduceByKey(_ + _)
val result = table.leftOuterJoin(sum).map{ case (a,(b,c)) => (a, b, c.getOrElse(-1) )}
and I get the right result:
result.collect.foreach(println)
output:
(002,0,7)
(002,7,7)
(001,1,6)
(001,2,6)
(001,3,6)
The problem is that there are two shuffle stages in this code: one in reduceByKey, the other in leftOuterJoin. If I wrote this in Hadoop MapReduce, it would be easy to get the same result with only one shuffle stage, by calling output.collect more than once in the reduce stage, roughly as sketched below.
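A minimal sketch of the reducer I have in mind, written against the newer org.apache.hadoop.mapreduce API (so context.write instead of the old OutputCollector.collect); the class name and output format are just illustrative:
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.Reducer
import scala.collection.JavaConverters._

class SumPerKeyReducer extends Reducer[Text, IntWritable, Text, Text] {
  override def reduce(key: Text,
                      values: java.lang.Iterable[IntWritable],
                      context: Reducer[Text, IntWritable, Text, Text]#Context): Unit = {
    // All f1 values for one ID arrive here after the single shuffle.
    // Buffer them so we can both sum them and re-emit each one.
    val buffered = values.asScala.map(_.get).toList
    val sum = buffered.sum
    // Emit one output record per original value: (ID, "f1<TAB>sum_f1").
    buffered.foreach(v => context.write(key, new Text(s"$v\t$sum")))
  }
}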
So I was wondering if there is a better way to do this with only one shuffle. Any suggestion would be appreciated.
Upvotes: 3
Views: 2702
Reputation: 2394
Another approach is to use aggregateByKey. It may be a difficult method to understand, but from the Spark docs:
(groupByKey) Note: This operation may be very expensive. If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using PairRDDFunctions.aggregateByKey or PairRDDFunctions.reduceByKey will provide much better performance.
Also, aggregateByKey is a generic function, so it's worth knowing. Of course, we are not doing a "simple aggregation such as sum" here (we also keep the list of values per key), so the performance benefit of this approach vs groupByKey may not be present.
Obviously benchmarking both approaches on real data would be a good idea.
Here is a detailed implementation:
import org.apache.spark.rdd.RDD

// The input as given by OP here: http://stackoverflow.com/questions/36455419/spark-reducebykey-and-keep-other-columns
val table = sc.parallelize(Seq(("001", 1), ("001", 2), ("001", 3), ("002", 0), ("002", 7)))
// zero is the initial value into which we will aggregate things.
// The second element is the sum.
// The first element is the list of values which contributed to this sum.
val zero = (List.empty[Int], 0)
// sequencer will receive an accumulator and the value.
// The accumulator will be reset for each key to 'zero'.
// In this sequencer we add value to the sum and append to the list because
// we want to keep both.
// This can be thought of as "map" stage in classic map/reduce.
def sequencer(acc: (List[Int], Int), value: Int) = {
  val (values, sum) = acc
  (value :: values, sum + value)
}
// combiner combines two lists and sums into one.
// The reason for this is the sequencer may run in different partitions
// and thus produce partial results. This step combines those partials into
// one final result.
// This step can be thought of as "reduce" stage in classic map/reduce.
def combiner(left: (List[Int], Int), right: (List[Int], Int)) = {
  (left._1 ++ right._1, left._2 + right._2)
}
// wiring it all together.
// Note the type of result it produces:
// Each key will have a list of values which contributed to the sum, and the sum itself.
val result: RDD[(String, (List[Int], Int))] = table.aggregateByKey(zero)(sequencer, combiner)
// To turn this to a flat list and print, use flatMap to produce:
// (key, value, sum)
val flatResult: RDD[(String, Int, Int)] = result.flatMap(result => {
  val (key, (values, sum)) = result
  for (value <- values) yield (key, value, sum)
})
// collect and print
flatResult.collect().foreach(println)
This produces:
(001,1,6)
(001,2,6)
(001,3,6)
(002,0,7)
(002,7,7)
Here is also a gist with a fully runnable version of the above if you want to reference it: https://gist.github.com/ppanyukov/253d251a16fbb660f225fb425d32206a
Upvotes: 1
Reputation: 27456
You could use groupByKey to get the list of values, take the sum, and recreate the lines with flatMapValues:
val g = table.groupByKey().flatMapValues { f1s =>
  val sum = f1s.reduce(_ + _)
  f1s.map(_ -> sum)
}
But reduce in this code runs locally on the full list of values for a key, so this will fail if a single key has too many values.
Another approach is to keep the join, but partition first, so the join is cheap:
val partitioned = table.partitionBy(
  new org.apache.spark.HashPartitioner(table.partitions.size))
partitioned.cache // May or may not improve performance.
val sum = partitioned.reduceByKey(_ + _)
val result = partitioned.join(sum)
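If you want the same (id, f1, sum_f1) shape as in the question, a final map (a narrow transformation, so no extra shuffle) can flatten the joined pairs; flat here is just an illustrative name:
// Flatten (id, (f1, sum_f1)) into (id, f1, sum_f1) without another shuffle.
val flat = result.map { case (id, (f1, sumF1)) => (id, f1, sumF1) }
flat.collect().foreach(println)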
I cannot guess which would be faster. I'd benchmark all of the options.
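A crude way to compare, assuming you just want rough numbers (the time helper below is only an illustration; for real results, use real data, warm any caches, and repeat the runs):
// Time a full action on each variant; adapt the expressions to whichever
// RDDs you want to compare (g and result come from the snippets above).
def time[T](label: String)(body: => T): T = {
  val start = System.nanoTime()
  val out = body
  println(s"$label took ${(System.nanoTime() - start) / 1e9} s")
  out
}

time("groupByKey variant") { g.count() }
time("partition + join variant") { result.count() }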
Upvotes: 0