blue-sky

Reputation: 53806

How to sort an RDD in Scala Spark?

Reading the docs for the Spark method sortByKey:

sortByKey([ascending], [numTasks])   When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument.

Is it possible to return just "N" results? Instead of returning all results, I would like to return just the top 10. I could convert the sorted collection to an Array and use the take method, but since this is an O(N) operation, is there a more efficient method?

Upvotes: 34

Views: 42733

Answers (3)

jruizaranguren

Reputation: 13597

Another option, at least from PySpark 1.2.0, is the use of takeOrdered.

In ascending order:

rdd.takeOrdered(10)

In descending order:

rdd.takeOrdered(10, lambda x: -x)

Top N by value for (k, v) pairs (tuple-unpacking in a lambda is Python 2-only, so index into the tuple instead):

rdd.takeOrdered(10, key=lambda kv: -kv[1])
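The ordering semantics of takeOrdered can be checked locally: it returns the n smallest elements under the key function, which is what Python's heapq.nsmallest does on a plain list. The pair data below is made up for illustration:

```python
import heapq

pairs = [("a", 3), ("b", 7), ("c", 1), ("d", 9)]  # toy stand-in for the RDD contents

# Same key as above: negate the value so the "smallest" keys are the largest values.
top2 = heapq.nsmallest(2, pairs, key=lambda kv: -kv[1])
print(top2)  # [('d', 9), ('b', 7)]
```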

Upvotes: 8

Daniel Darabos

Reputation: 27455

If you only need the top 10, use rdd.top(10). It avoids sorting, so it is faster.

rdd.top makes one parallel pass through the data, collecting the top N of each partition in a heap, then merges the heaps. It is an O(rdd.count) operation. Sorting would be O(rdd.count log rdd.count) and would incur a lot of data transfer: it does a shuffle, so all of the data would be transmitted over the network.
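The per-partition-heap-then-merge strategy can be sketched in plain Python, no Spark needed (top_n and the sample partitions are hypothetical names for illustration):

```python
import heapq

def top_n(partitions, n):
    # Mimic rdd.top: take the n largest in each partition (a bounded
    # heap per partition), then merge those small results and take the
    # n largest overall -- one pass over the data, no full sort.
    per_partition = [heapq.nlargest(n, part) for part in partitions]
    merged = (x for heap in per_partition for x in heap)
    return heapq.nlargest(n, merged)

partitions = [[5, 1, 9], [12, 3], [7, 8, 2]]
print(top_n(partitions, 3))  # [12, 9, 8]
```

Only n elements per partition ever cross the "merge" boundary, which is why this avoids the cost of a full sort.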

Upvotes: 51

WestCoastProjects

Reputation: 63062

Most likely you have already perused the source code:

  class OrderedRDDFunctions {
    // <snip>
    def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
      val part = new RangePartitioner(numPartitions, self, ascending)
      val shuffled = new ShuffledRDD[K, V, P](self, part)
      shuffled.mapPartitions(iter => {
        val buf = iter.toArray
        if (ascending) {
          buf.sortWith((x, y) => x._1 < y._1).iterator
        } else {
          buf.sortWith((x, y) => x._1 > y._1).iterator
        }
      }, preservesPartitioning = true)
    }
  }

And, as you say, the entire dataset must go through the shuffle stage, as seen in the snippet.
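The range-partition-then-local-sort scheme from the snippet can be mimicked locally; the split points in `bounds` stand in for what RangePartitioner would derive by sampling, and all names here are illustrative:

```python
data = [(5, "e"), (1, "a"), (9, "i"), (3, "c"), (7, "g")]
bounds = [4, 8]  # stand-in for RangePartitioner's sampled split points

def partition_of(key):
    # Route each key to the first range it falls into.
    for i, upper in enumerate(bounds):
        if key <= upper:
            return i
    return len(bounds)

# "Shuffle": move every pair to its range partition.
parts = [[] for _ in range(len(bounds) + 1)]
for kv in data:
    parts[partition_of(kv[0])].append(kv)

# Local sort inside each partition; the concatenation is globally sorted.
result = [kv for part in parts for kv in sorted(part)]
print(result)  # [(1, 'a'), (3, 'c'), (5, 'e'), (7, 'g'), (9, 'i')]
```

Because the partitions cover disjoint, ordered key ranges, no cross-partition merge is needed after the local sorts.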

However, your concern about subsequently invoking take(K) may not be so accurate. This operation does NOT cycle through all N items:

  /**
   * Take the first num elements of the RDD. It works by first scanning one partition, and use the
   * results from that partition to estimate the number of additional partitions needed to satisfy
   * the limit.
   */
  def take(num: Int): Array[T] = {
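That early-exit behaviour can be sketched locally: scan partitions in order and stop as soon as num elements are collected. This is a simplified, single-pass model of the real estimate-and-retry logic described in the doc comment:

```python
def take(partitions, num):
    # Simplified model of RDD.take: visit partitions one at a time and
    # stop early, rather than touching every element in the dataset.
    out = []
    for part in partitions:
        for x in part:
            out.append(x)
            if len(out) == num:
                return out
    return out

partitions = [[10, 20, 30], [40, 50], [60]]
print(take(partitions, 4))  # [10, 20, 30, 40] -- the last partition is never scanned
```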

So then, it would seem:

O(myRdd.take(K)) << O(myRdd.sortByKey()) ~= O(myRdd.sortByKey().take(k)) (at least for small K) << O(myRdd.sortByKey().collect())

Upvotes: 19
