Reputation: 342
I have a large dataset and I would like to find the rows with the n highest values.
id, count
id1, 10
id2, 15
id3, 5
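The only method I can think of is using row_number without a partition, like:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, desc, row_number}
val window = Window.orderBy(desc("count"))
df.withColumn("row_number", row_number().over(window)).filter(col("row_number") <= n)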
But this is not performant at all when the data contains millions or billions of rows, because a window without partitionBy moves all the data into a single partition and I get an OOM.
Has anyone managed to come up with a performant solution?
Upvotes: 3
Views: 3916
Reputation: 18013
Here is an actual example (slightly updated) of the roll-your-own approach, for posterity.
import org.apache.spark.sql.functions._
import spark.sqlContext.implicits._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}
// 1. Create data
val data = Seq(("James ","","Smith","36636","M",33000),
("Michael ","Rose","","40288","M",14000),
("Robert ","","Williams","42114","M",40),
("Robert ","","Williams","42114","M",540),
("Robert ","","Zeedong","42114","M",40000000),
("Maria ","Anne","Jones","39192","F",300),
("Maria ","Anne","Vangelis","39192","F",1300))
val columns = Seq("firstname","middlename","lastname","dob","gender","val")
val df = data.toDF(columns:_*)
// 2. Repartition into any number of partitions and sort each partition by val, descending (like a Hadoop combiner).
val df2 = df.repartition(1000,col("lastname")).sortWithinPartitions(desc("val"))
// 3. Take the top n rows of each partition (here n = 2, so at most 1000 x 2 rows survive). take(n) runs per partition, so no OOM.
val rdd2 = df2.rdd.mapPartitions(_.take(2))
//4. Ghastly Row to DF work-arounds.
val schema = new StructType()
.add(StructField("f", StringType, true))
.add(StructField("m", StringType, true))
.add(StructField("l", StringType, true))
.add(StructField("d", StringType, true))
.add(StructField("g", StringType, true))
.add(StructField("v", IntegerType, true))
val df3 = spark.createDataFrame(rdd2, schema)
// 5. Sort the surviving rows, take the global top n = 2, and Bob's your uncle: the reduce after the combine.
df3.orderBy(desc("v")).limit(2).show()
Returns for top 2 desc:
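+-------+-+-------+-----+-+--------+
|      f|m|      l|    d|g|       v|
+-------+-+-------+-----+-+--------+
|Robert | |Zeedong|42114|M|40000000|
| James | |  Smith|36636|M|   33000|
+-------+-+-------+-----+-+--------+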
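The partition-then-merge idea in steps 2 to 5 can be modelled in plain Scala without Spark. It also shows why nothing is lost: every row in the global top n is necessarily in the top n of its own partition, so it survives step 3. In this minimal sketch, partitions are assumed to be a Seq[Seq[Int]] of values, and the object name TopNByPartition is hypothetical.

```scala
// A plain-Scala model (no Spark) of the "combiner then reducer" top-n approach.
// Assumption for illustration: each partition is a Seq[Int] of values.
object TopNByPartition {

  // Step 3 analogue: sort each partition descending and keep its first n values.
  def perPartitionTopN(partitions: Seq[Seq[Int]], n: Int): Seq[Seq[Int]] =
    partitions.map(_.sorted(Ordering[Int].reverse).take(n))

  // Step 5 analogue: gather the survivors (at most partitions.size * n values),
  // sort them once more and keep the global top n.
  def globalTopN(partitions: Seq[Seq[Int]], n: Int): Seq[Int] =
    perPartitionTopN(partitions, n).flatten.sorted(Ordering[Int].reverse).take(n)
}
```

For example, splitting the val column above into three hypothetical partitions, TopNByPartition.globalTopN(Seq(Seq(33000, 14000), Seq(40, 540, 40000000), Seq(300, 1300)), 2) gives Seq(40000000, 33000), matching the Spark output.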
Upvotes: 2
Reputation: 5068
I see two ways to improve your algorithm's performance. The first is to use sort and limit to retrieve the top n rows. The second is to develop a custom Aggregator.
You sort your dataframe and then take the first n rows:
val n: Int = ???
import org.apache.spark.sql.functions.desc
df.orderBy(desc("count")).limit(n)
Spark optimizes this sequence of transformations: it first sorts each partition and takes the first n rows of each one, then gathers those rows into a single final partition, where it sorts them again and takes the final first n rows. You can check this by calling explain() on the transformations. For example, with n = 3 you get the following execution plan:
== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]
You can also confirm this by looking at how the TakeOrderedAndProject step is executed in limit.scala in Spark's source code (case class TakeOrderedAndProjectExec, method doExecute).
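The two-phase pattern behind TakeOrderedAndProject (a per-partition top n, then a final top n over at most p·n survivors) can be sketched in plain Scala. This sketch uses a bounded min-heap so each partition needs only O(n) extra memory; the object name TakeOrdered and the Iterator[Int] partition model are assumptions for illustration, and this is not Spark's actual implementation.

```scala
import scala.collection.mutable

// Illustration only: a bounded min-heap keeps the n largest values seen so far.
object TakeOrdered {

  // Top n values of one partition, largest first: O(k log n) time, O(n) memory.
  def topN(partition: Iterator[Int], n: Int): Seq[Int] = {
    // PriorityQueue dequeues the maximum of its ordering; with the reversed
    // ordering, the head is the smallest value currently kept.
    val heap = mutable.PriorityQueue.empty[Int](Ordering[Int].reverse)
    partition.foreach { v =>
      if (heap.size < n) heap.enqueue(v)
      else if (n > 0 && v > heap.head) {
        heap.dequeue() // evict the smallest kept value
        heap.enqueue(v)
      }
    }
    heap.toSeq.sorted(Ordering[Int].reverse)
  }

  // Phase 1: top n per partition. Phase 2: top n over the survivors.
  def topNAcrossPartitions(partitions: Seq[Seq[Int]], n: Int): Seq[Int] =
    topN(partitions.iterator.flatMap(p => topN(p.iterator, n)), n)
}
```

With the question's sample data split into two partitions, TakeOrdered.topNAcrossPartitions(Seq(Seq(10, 15), Seq(5)), 2) keeps 15 and 10.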
For the custom aggregator, you create an Aggregator that populates and updates an ordered array of the top n records:
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {
  // Empty buffer; the buffer is always kept sorted by count, descending
  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  // Insert the current record at its sorted position, keeping at most `limit` records
  def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
    val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
    if (topRecords.length < limit) {
      topRecords.insert(insertIndex + 1, currentRecord)
    } else if (insertIndex < limit - 1) {
      topRecords.insert(insertIndex + 1, currentRecord)
      topRecords.remove(topRecords.length - 1)
    }
    topRecords
  }

  // Merge two sorted buffers, keeping the `limit` records with the highest counts
  def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
      if (topRecords1.isEmpty) {
        merged += topRecords2.remove(0)
      } else if (topRecords2.isEmpty) {
        merged += topRecords1.remove(0)
      } else if (topRecords2.head.count < topRecords1.head.count) {
        merged += topRecords1.remove(0)
      } else {
        merged += topRecords2.remove(0)
      }
    }
    merged
  }

  def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

  def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]()

  def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]()
}
Then you apply this aggregator to your dataframe and flatten the aggregation result:
val n: Int = ???
import sparkSession.implicits._
df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)
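To see how the Aggregator's methods fit together, here is a plain-Scala model (no Spark) of the usual lifecycle: zero and reduce inside each partition, merge across partitions, then finish. TopRecordsModel is a hypothetical name, partitions are assumed to be a Seq[Seq[Record]], and the merge here reads its inputs by index instead of removing from them, but it produces the same result as the merge above.

```scala
import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

// Plain-Scala model of the TopRecords aggregator lifecycle (illustration only).
case class TopRecordsModel(limit: Int) {

  def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

  // Same insertion logic as TopRecords.reduce: the buffer stays sorted by count, descending.
  def reduce(top: ArrayBuffer[Record], r: Record): ArrayBuffer[Record] = {
    val insertIndex = top.lastIndexWhere(_.count > r.count)
    if (top.length < limit) top.insert(insertIndex + 1, r)
    else if (insertIndex < limit - 1) {
      top.insert(insertIndex + 1, r)
      top.remove(top.length - 1)
    }
    top
  }

  // Two-way merge of sorted buffers, keeping at most `limit` records.
  def merge(a: ArrayBuffer[Record], b: ArrayBuffer[Record]): ArrayBuffer[Record] = {
    val merged = ArrayBuffer.empty[Record]
    var i = 0
    var j = 0
    while (merged.length < limit && (i < a.length || j < b.length)) {
      if (j >= b.length || (i < a.length && b(j).count < a(i).count)) { merged += a(i); i += 1 }
      else { merged += b(j); j += 1 }
    }
    merged
  }

  // reduce within each partition, then merge the partial results.
  def aggregate(partitions: Seq[Seq[Record]]): Seq[Record] =
    partitions
      .map(p => p.foldLeft(zero)(reduce))
      .foldLeft(zero)(merge)
      .toSeq
}
```

For instance, TopRecordsModel(2).aggregate(Seq(Seq(Record("id1", 10), Record("id3", 5)), Seq(Record("id2", 15)))) returns id2 then id1, as expected for the question's sample data.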
To compare these two methods, say we want to take the top n rows of a dataframe distributed over p partitions, each holding around k records, so the dataframe has p·k rows in total. This gives the following complexities (subject to errors):
| method | total number of operations | memory consumption (on executor) | memory consumption (on final executor) |
|---|---|---|---|
| Current code | O(p·k·log(p·k)) | -- | O(p·k) |
| Sort and Limit | O(p·k·log(k) + p·n·log(p·n)) | O(k) | O(p·n) |
| Custom Aggregator | O(p·k) | O(k) + O(n) | O(p·n) |
So in terms of number of operations, the Custom Aggregator is the most performant, as long as n is small (each insertion into its ordered buffer costs up to O(n)). However, it is by far the most complex method and involves a lot of serialization/deserialization, so it may be less performant than Sort and Limit in certain cases.
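To get a feel for these expressions, you can plug in sizes. This sketch evaluates the table's big-O terms with constants dropped and base-2 logarithms; the values p = 1000, k = 1,000,000 and n = 10 used in the usage note are hypothetical, and the results are rough operation counts, not measurements. It also includes p·k·n as a more cautious figure for the Custom Aggregator, since each reduce call scans and inserts into an ordered buffer of up to n records.

```scala
// Rough operation counts from the complexity table (constants dropped, log base 2).
object CostEstimate {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Global sort of all p·k rows in a single partition.
  def currentCode(p: Double, k: Double): Double = p * k * log2(p * k)

  // Per-partition sort of k rows, then a final sort of p·n survivors.
  def sortAndLimit(p: Double, k: Double, n: Double): Double =
    p * k * log2(k) + p * n * log2(p * n)

  // The table's figure: one pass over every row.
  def customAggregator(p: Double, k: Double): Double = p * k

  // More cautious figure: each of the p·k reduce calls may touch up to n buffer slots.
  def customAggregatorWithInsertCost(p: Double, k: Double, n: Double): Double = p * k * n
}
```

With those hypothetical sizes, the estimates rank Current code above Sort and Limit above the Custom Aggregator, and the cautious p·k·n figure (1e10) still stays below Sort and Limit (about 2e10). Real timings also depend on shuffles and serialization, which these counts ignore.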
So you have two ways to efficiently take the top n rows: Sort and Limit, and Custom Aggregator. To choose between them, benchmark both on your real dataframe. If, after benchmarking, Sort and Limit turns out to be only a bit slower than the Custom Aggregator, I would choose Sort and Limit, as its code is a lot easier to maintain.
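For a first comparison, a small best-of-N wall-clock helper is often enough. This is a hypothetical helper (Bench.time is not a Spark API). Spark transformations are lazy, so the timed block must end in an action such as collect() or count(), otherwise only plan construction is measured.

```scala
// Minimal best-of-N wall-clock timing helper (hypothetical, for rough comparisons only).
object Bench {
  def time[A](label: String, runs: Int = 3)(block: => A): (A, Long) = {
    require(runs > 0, "runs must be positive")
    var result: Option[A] = None
    var bestMs = Long.MaxValue
    (1 to runs).foreach { _ =>
      val start = System.nanoTime()
      result = Some(block) // the by-name block is re-evaluated on every run
      bestMs = math.min(bestMs, (System.nanoTime() - start) / 1000000L)
    }
    println(s"$label: best of $runs runs = $bestMs ms")
    (result.get, bestMs)
  }
}
```

Assuming the df, n, Record and TopRecords definitions above, with sparkSession.implicits._ in scope, you could compare Bench.time("sort and limit")(df.orderBy(desc("count")).limit(n).collect()) against Bench.time("custom aggregator")(df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record).collect()). Taking the best run rather than the first reduces the effect of JVM warm-up.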
Upvotes: 3