anthonybell

Reputation: 5998

Spark - how to get top N of rdd as a new rdd (without collecting at the driver)

I am wondering how to filter an RDD so that it keeps only the rows with one of the top N values. Usually I would sort the RDD, take the top N items as an array in the driver to find the Nth value, and broadcast that threshold to filter the RDD, like so:

val topNvalues = sc.broadcast(rdd.map(_.fieldToThreshold).distinct.sortBy(x => -x).take(N))
val threshold = topNvalues.value.last
val rddWithTopNValues = rdd.filter(_.fieldToThreshold >= threshold)

but in this case my N is too large, so how can I do this purely with RDDs, like so?

def getExpensiveItems(itemPrices: RDD[(Int, Float)], count: Int): RDD[(Int, Float)] = {
  val sortedPrices = itemPrices.sortBy(-_._2).map(_._1).distinct

  // How to do this without collecting results to the driver??
  val highPrices = sortedPrices.getTopNValuesWithoutCollect(count)

  itemPrices.join(highPrices.keyBy(x => x)).map { case (id, (price, _)) => (id, price) }
}

Upvotes: 8

Views: 3687

Answers (1)

elm

Reputation: 20435

Use zipWithIndex on the sorted rdd and then filter by the index to keep up to n items. To illustrate, consider this rdd sorted in descending order,

val rdd = sc.parallelize((1 to 10).map( _ => math.random)).sortBy(-_)

Then

rdd.zipWithIndex.filter(_._2 < 4)

delivers the top four items without collecting the rdd to the driver.
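Applied to the function in the question, the same idea might look like the sketch below. It assumes the signature from the question (`getExpensiveItems`) and keeps the rows whose price ranks among the top `count` distinct prices, joining back on price rather than collecting at the driver:

```scala
import org.apache.spark.rdd.RDD

// Sketch: keep rows whose price is among the top `count` distinct prices,
// using only RDD operations (no collect at the driver).
def getExpensiveItems(itemPrices: RDD[(Int, Float)], count: Int): RDD[(Int, Float)] = {
  // Distinct prices, sorted descending, ranked via zipWithIndex,
  // and cut off at the first `count` ranks.
  val topPrices: RDD[Float] = itemPrices
    .map(_._2)
    .distinct()
    .sortBy(p => -p)
    .zipWithIndex()
    .filter { case (_, rank) => rank < count }
    .map { case (price, _) => price }

  // Re-key the input by price and join with the top prices
  // to recover the (id, price) rows.
  itemPrices
    .map { case (id, price) => (price, id) }
    .join(topPrices.keyBy(p => p))
    .map { case (price, (id, _)) => (id, price) }
}
```

Note that `zipWithIndex` preserves the partition ordering of the sorted RDD, so the index is a valid rank; the join does cost a shuffle, which is the price of staying distributed.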

Upvotes: 10
