Regalia9363

Reputation: 342

Spark: Performant way to find top n values

I have a large dataset and I would like to find the rows with the n highest values.

id, count
id1, 10
id2, 15
id3, 5
...

The only method I can think of is using row_number without a partition, like

val window = Window.orderBy(desc("count"))

df.withColumn("row_number", row_number over window).filter(col("row_number") <= n)

but this is in no way performant when the data contains millions or billions of rows, because it pushes all the data into a single partition and I get an OOM.

Has anyone managed to come up with a performant solution?

Upvotes: 3

Views: 3916

Answers (2)

Ged

Reputation: 18013

  1. Convert to an RDD
  2. mapPartitions, sorting the data within each partition and taking N
  3. Convert back to a DF
  4. Then sort and rank and take the top N. It is unlikely you will get an OOM.

Actual example, slightly updated; a roll-your-own approach, for posterity.

import org.apache.spark.sql.functions._
import spark.sqlContext.implicits._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}

// 1. Create data
val data = Seq(("James ","","Smith","36636","M",33000),
                 ("Michael ","Rose","","40288","M",14000),
                 ("Robert ","","Williams","42114","M",40),
                 ("Robert ","","Williams","42114","M",540),
                 ("Robert ","","Zeedong","42114","M",40000000),
                 ("Maria ","Anne","Jones","39192","F",300),
                 ("Maria ","Anne","Vangelis","39192","F",1300),
                 ("Jen","Mary","Brown","","F",-1))

val columns = Seq("firstname","middlename","lastname","dob","gender","val")
val df = data.toDF(columns:_*)
//df.show()

//2. Repartition to any number of partitions and sort within each partition - a combiner function, like in Hadoop.
val df2 = df.repartition(1000,col("lastname")).sortWithinPartitions(desc("val")) 
//df2.rdd.glom().collect()
 
//3. Take the top N per partition, so at most (num partitions x 2) rows in this case. The take(n) gives the top n of each partition because each partition is already sorted. No OOM.
val rdd2 = df2.rdd.mapPartitions(_.take(2))

//4. Ghastly Row to DF work-arounds.
val schema = new StructType()
  .add(StructField("f", StringType, true))
  .add(StructField("m", StringType, true))
  .add(StructField("l", StringType, true))
  .add(StructField("d", StringType, true))
  .add(StructField("g", StringType, true))
  .add(StructField("v", IntegerType, true))

val df3 = spark.createDataFrame(rdd2, schema)

//5. Sort and take top(n) = 2 and Bob's your uncle. The Reduce after Combine.
df3.sort(col("v").desc).limit(2).show()

Returns for top 2 desc:

+-------+---+-------+-----+---+--------+
|      f|  m|      l|    d|  g|       v|
+-------+---+-------+-----+---+--------+
|Robert |   |Zeedong|42114|  M|40000000|
| James |   |  Smith|36636|  M|   33000|
+-------+---+-------+-----+---+--------+ 
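
As a side note, the Row-to-DataFrame work-around in step 4 can be avoided by staying in the typed Dataset API. The following is only a sketch of that variant, not part of the original answer; it assumes the "val" column is renamed, because val is a Scala keyword and cannot be used as a case class field name:

import org.apache.spark.sql.functions.{col, desc}
import spark.implicits._

// Hypothetical case class mirroring the example schema, with "val" renamed.
case class Person(firstname: String, middlename: String, lastname: String,
                  dob: String, gender: String, amount: Int)

val ds = df.withColumnRenamed("val", "amount").as[Person]

// Combiner: sort within each partition and keep only its top 2, staying typed
// so no schema has to be rebuilt afterwards.
val topPerPartition = ds
  .repartition(1000, col("lastname"))
  .sortWithinPartitions(desc("amount"))
  .mapPartitions(_.take(2))

// Reducer: global sort of the small surviving set, then the final top 2.
topPerPartition.orderBy(desc("amount")).limit(2).show()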

    

Upvotes: 2

Vincent Doba

Reputation: 5068

I see two methods to improve your algorithm's performance. The first is to use sort and limit to retrieve the top n rows. The second is to develop a custom Aggregator.

Sort and Limit method

You sort your dataframe and then you take the first n rows:

val n: Int = ???

import org.apache.spark.sql.functions.desc

df.orderBy(desc("count")).limit(n)

Spark optimizes this sequence of transformations: it first sorts each partition and takes the first n rows of each, then gathers those candidates onto a single final partition, sorts again and takes the final first n rows. You can check this by calling explain() on the transformations. You get the following execution plan:

== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]

You can also see how the TakeOrderedAndProject step is executed by looking at limit.scala in Spark's source code (case class TakeOrderedAndProjectExec, method doExecute).
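
If you want to express the same per-partition top-n then merge behaviour by hand, the RDD API's takeOrdered follows the same idea: each partition keeps only a bounded set of candidates, and only those candidates are merged. A minimal sketch, not part of the original answer, assuming count is an Int column as in the question:

import org.apache.spark.sql.Row

// Sketch only: keep a bounded top-n per partition, merge the candidates,
// and return the n rows to the driver (fine while n stays small).
val topRows: Array[Row] =
  df.rdd.takeOrdered(n)(Ordering.by[Row, Int](row => -row.getAs[Int]("count")))

Note that, unlike orderBy(...).limit(n), this returns an Array[Row] on the driver instead of a DataFrame.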

Custom Aggregator method

For the custom aggregator method, you create an Aggregator that builds and updates an ordered array of the top n rows.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {

  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged.append(topRecords2.remove(0))
      } else if (topRecords2.isEmpty) {
        merged.append(topRecords1.remove(0))
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged.append(topRecords1.remove(0))
      } else {
        merged.append(topRecords2.remove(0))
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]

  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]

}

Then you apply this aggregator to your dataframe and flatten the aggregation result:

val n: Int = ???

import sparkSession.implicits._

df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)
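
As a quick sanity check, the aggregator can be exercised on the question's sample values. This is a hypothetical snippet, not from the answer, and it assumes the SparkSession is named sparkSession as in the import above:

import sparkSession.implicits._

// Hypothetical usage sketch: take the top 2 records from the sample data.
val sample = Seq(Record("id1", 10), Record("id2", 15), Record("id3", 5)).toDS()

sample.select(TopRecords(2).toColumn)
  .flatMap(records => records)
  .show()
// Expected: the rows for id2 (count 15) and id1 (count 10).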

Method comparison

To compare those two methods, let's say we want to take the top n rows of a dataframe that is distributed over p partitions, each partition holding around k records, so the dataframe has roughly p·k rows. This gives the following complexities (subject to errors):

Method            | Total number of operations   | Memory consumption (on executor) | Memory consumption (on final executor)
Current code      | O(p·k·log(p·k))              | --                               | O(p·k)
Sort and Limit    | O(p·k·log(k) + p·n·log(p·n)) | O(k)                             | O(p·n)
Custom Aggregator | O(p·k)                       | O(k) + O(n)                      | O(p·n)
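
To get a feel for these orders of magnitude, here is a purely illustrative back-of-the-envelope computation; the values of p, k and n are made up, not measured:

// Illustrative only: plug p = 1000 partitions, k = 1e6 rows per partition and
// n = 10 into the formulas above (log base 2, constant factors ignored).
def log2(x: Double): Double = math.log(x) / math.log(2)

val (p, k, n) = (1000.0, 1e6, 10.0)

val currentCode      = p * k * log2(p * k)                    // ~3.0e10
val sortAndLimit     = p * k * log2(k) + p * n * log2(p * n)  // ~2.0e10
val customAggregator = p * k                                  // ~1.0e9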

So regarding the number of operations, the Custom Aggregator is the most performant. However, this method is by far the most complex and involves a lot of serialization/deserialization, so it may be less performant than Sort and Limit in certain cases.

Conclusion

You have two methods to efficiently take the top n rows: Sort and Limit and Custom Aggregator. To select which one to use, you should benchmark both methods with your real dataframe. If, after benchmarking, Sort and Limit turns out to be only a bit slower than the Custom Aggregator, I would still select Sort and Limit, as its code is a lot easier to maintain.

Upvotes: 3
