samba

Reputation: 3091

How to find an average for a Spark RDD?

I have read that the function passed to reduce must be commutative and associative. How should I write a function that computes the average and meets this requirement? When I apply the function below to an RDD, the result is not the correct average. Could anyone explain what is wrong with it?

My guess is that it takes two elements, say 1 and 2, and applies the function to them, giving (1 + 2) / 2. It then adds the next element, 3, to that result and divides by 2 again, and so on.

val rdd = sc.parallelize(1 to 100)

rdd.reduce((_ + _) / 2)

Upvotes: 1

Views: 11364

Answers (3)

Naveen Rangu

Reputation: 11

Try this:

val lt = sc.parallelize(List(2, 4, 5, 7, 2))

lt.sum/lt.count

Upvotes: 1

chubock

Reputation: 844

You can also map each element to a (value, 1) pair and reduce the pairs, so that the sum of the elements and the count of the elements are tracked together.

val pair = sc.parallelize(1 to 100)
.map(x => (x, 1))
.reduce((x, y) => (x._1 + y._1, x._2 + y._2))

// convert to Double first; Int / Int would truncate 50.5 to 50
val mean = pair._1.toDouble / pair._2
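A minimal sketch of why the (sum, count) pair is safe to use with reduce: the combine step is plain component-wise addition, which is associative and commutative, so the way the data is split up should not change the result. Plain Scala collections stand in for RDD partitions here; the `partitions` layout and the `combine` and `mean` names are assumptions for illustration, not Spark API.

```scala
object PairMeanSketch {
  // Component-wise addition of (sum, count) pairs: associative and commutative.
  def combine(a: (Long, Long), b: (Long, Long)): (Long, Long) =
    (a._1 + b._1, a._2 + b._2)

  // Reduce each hypothetical partition on its own, then merge the partial
  // results, roughly mirroring how a reduce over split-up data is evaluated.
  def mean(partitions: Seq[Seq[Int]]): Double = {
    val (sum, count) = partitions
      .filter(_.nonEmpty)                               // skip empty partitions
      .map(_.map(x => (x.toLong, 1L)).reduce(combine))  // per-partition result
      .reduce(combine)                                  // merge partial results
    sum.toDouble / count                                // Double division
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 100).toSeq
    // Two different partitionings of the same data give the same mean.
    println(mean(data.grouped(7).toSeq))  // 50.5
    println(mean(data.grouped(33).toSeq)) // 50.5
  }
}
```

The division happens only once, at the very end, which is what keeps the reduce step itself a simple sum.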

Upvotes: 1

Leo C

Reputation: 22439

rdd.reduce((_ + _) / 2)

There are a few issues with the above reduce method for average calculation:

  1. The placeholder syntax won't work here: (_ + _) / 2 is not shorthand for (acc, x) => (acc + x) / 2, because the underscores bind inside the parentheses.
  2. Since your RDD holds integers, rdd.reduce((acc, x) => (acc + x) / 2) performs an integer division at every step, which truncates the intermediate results (certainly incorrect for calculating an average).
  3. Even on doubles, this reduce will not produce the average of the list. For example:

    List[Double](1, 2, 3).reduce((a, x) => (a + x) / 2)
    --> (1.0 + 2.0) / 2 = 1.5
    --> (1.5 + 3.0) / 2 = 2.25
    Result: 2.25
    

    whereas:

    Average of List[Double](1, 2, 3) = 2.0
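To tie this back to the commutative/associative requirement: pairwise averaging is commutative but not associative, so the result depends on how the elements are grouped, and on an RDD the grouping depends on how the data is partitioned. A small sketch in plain Scala (no Spark needed; `avg2` is just a name picked for illustration):

```scala
object PairwiseAverageSketch {
  // The operation from the question, on Doubles to rule out integer division.
  def avg2(a: Double, b: Double): Double = (a + b) / 2

  def main(args: Array[String]): Unit = {
    val leftFirst  = avg2(avg2(1.0, 2.0), 3.0) // ((1 + 2) / 2 + 3) / 2
    val rightFirst = avg2(1.0, avg2(2.0, 3.0)) // (1 + (2 + 3) / 2) / 2
    println(leftFirst)                         // 2.25
    println(rightFirst)                        // 1.75
    println(List(1.0, 2.0, 3.0).sum / 3)       // 2.0, the actual average
  }
}
```

Neither grouping gives 2.0, and the two groupings disagree with each other, so the same reduce could return different values for the same data.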
    

How should I write a [reduce] function to find the average so it conforms with this requirement?

I'm not sure reduce is suitable for directly calculating the average. You can certainly use reduce(_ + _) to sum the elements and then divide the sum by the element count, like:

rdd.reduce(_ + _) / rdd.count.toDouble

But then you can simply use RDD's built-in function mean:

rdd.mean

Upvotes: 3
