Sam

Reputation: 358

Top N items from a Spark DataFrame/RDD

My requirement is to get the top N items from a DataFrame.

I have this DataFrame:

val df = List(
      ("MA", "USA"),
      ("MA", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA")).toDF("value", "country")

I was able to map it to an RDD[((Int, String), Long)] named colValCount, read as ((colIdx, value), count):

((0,CT),5)
((0,MA),2)
((0,OH),4)
((0,NY),6)
((1,USA),17)
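(For illustration, one way such an RDD could be built; this is a sketch under the assumption that every column is counted as a string, not necessarily the code actually used:)

import org.apache.spark.rdd.RDD

// Explode each row into ((colIdx, value), 1) pairs and sum the counts per pair.
val colValCount: RDD[((Int, String), Long)] = df.rdd
  .flatMap(row => row.toSeq.zipWithIndex.map { case (v, idx) => ((idx, String.valueOf(v)), 1L) })
  .reduceByKey(_ + _)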

Now I need to get the top 2 items for each column index. So my expected output is this:

RDD[((Int, String), Long)]

((0,CT),5)
((0,NY),6)
((1,USA),17)

I've tried using the freqItems API on the DataFrame, but it's slow.

Any suggestions are welcome.

Upvotes: 4

Views: 29242

Answers (4)

Gianmario Spacagna

Reputation: 1300

You can map each partition using this helper function defined in Sparkz and then combine the partial results:

package sparkz.utils

import scala.reflect.ClassTag

object TopElements {
  // Keeps the n highest-scoring elements in a single pass over elems,
  // tracking the current minimum score of the running set.
  def topN[T: ClassTag](elems: Iterable[T])(scoreFunc: T => Double, n: Int): List[T] =
    elems.foldLeft((Set.empty[(T, Double)], Double.MaxValue)) {
      case (accumulator@(topElems, minScore), elem) =>
        val score = scoreFunc(elem)
        if (topElems.size < n)
          // Still filling the set: add the element unconditionally.
          (topElems + (elem -> score), math.min(minScore, score))
        else if (score > minScore) {
          // Replace the current weakest element and recompute the minimum score.
          val newTopElems = topElems - topElems.minBy(_._2) + (elem -> score)
          (newTopElems, newTopElems.map(_._2).min)
        }
        else accumulator
    }
      ._1.toList.sortBy(_._2).reverse.map(_._1) // highest score first
}

Source: https://github.com/gm-spacagna/sparkz/blob/master/src/main/scala/sparkz/utils/TopN.scala
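Applied to the colValCount RDD from the question, usage could look roughly like this (a sketch that groups by column index rather than mapping partitions, which is the simplest way to use the helper here):

import sparkz.utils.TopElements

// Group the counts by column index and keep the 2 highest-count values per index.
val top2 = colValCount
  .map { case ((idx, value), count) => (idx, (value, count)) }
  .groupByKey()
  .flatMap { case (idx, entries) =>
    TopElements.topN(entries)(_._2.toDouble, 2).map { case (v, c) => ((idx, v), c) }
  }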

Upvotes: 0

Kirk Broadhurst

Reputation: 28718

The easiest way to do this (it is a natural window-function problem) is by writing SQL. Spark supports SQL syntax, and SQL is a great and expressive tool for this problem.

Register your DataFrame as a temp table, and then group and window on it:

spark.sql("""SELECT idx, value, ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) as r 
             FROM (
               SELECT idx, value, COUNT(*) as c 
               FROM (SELECT 0 as idx, value FROM df UNION ALL SELECT 1, country FROM df) 
               GROUP BY idx, value) 
             HAVING r <= 2""").show()

I'd like to see if any of the procedural / Scala approaches will let you perform the window function without an iteration or loop. I'm not aware of anything in the Spark API that would support it.

Incidentally, if you have an arbitrary number of columns you want to include, you can quite easily generate the inner section (SELECT 0 AS idx, value ... UNION ALL SELECT 1, country, etc.) dynamically from the list of columns; one way is sketched below.
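A rough sketch of that idea (an illustration under the assumption that every column can be cast to a string, not code from the answer itself): build the stacked (idx, value) relation from df.columns, register it, and reuse the same windowed query over it.

import org.apache.spark.sql.functions.{col, lit}

// Stack every column of df into a single (idx, value) relation.
val stacked = df.columns.zipWithIndex
  .map { case (c, i) => df.select(lit(i).alias("idx"), col(c).cast("string").alias("value")) }
  .reduce(_ union _)

stacked.createOrReplaceTempView("stacked")   // then run the window query against "stacked" instead of "df"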

Upvotes: 3

Alper t. Turker

Reputation: 35229

For example:

import org.apache.spark.sql.functions._

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))
  .groupBy($"index", $"value")
  .count
  .orderBy($"count".desc)
  .limit(3)
  .show

// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    1|  USA|   17|
// |    0|   NY|    6|
// |    0|   CT|    5|
// +-----+-----+-----+

where:

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))

creates a two column DataFrame:

// +-----+-----+
// |index|value|
// +-----+-----+
// |    0|   MA|
// |    0|   MA|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    1|  USA|
// |    1|  USA|
// |    1|  USA|
// +-----+-----+

If you want specifically the top two values for each column:

import org.apache.spark.sql.DataFrame

def topN(df: DataFrame, key: String, n: Int) = {
  df.select(
      lit(df.columns.indexOf(key)).alias("index"),
      col(key).alias("value"))
    .groupBy("index", "value")
    .count
    .orderBy($"count".desc)  // sort by descending count to keep the top values
    .limit(n)
}

topN(df, "value", 2).union(topN(df, "country", 2)).show
// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    0|   NY|    6|
// |    0|   CT|    5|
// |    1|  USA|   17|
// +-----+-----+-----+

So, as pault said, it is just "some combination of sort() and limit()".
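For completeness, here is a rough alternative sketch (not from this answer) that handles all column indices in one pass by ranking the counts with a window function in the DataFrame API:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Count (index, value) pairs as above, then keep the 2 highest counts per index.
val counts = df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))
  .groupBy($"index", $"value")
  .count

counts
  .withColumn("r", row_number().over(Window.partitionBy($"index").orderBy($"count".desc)))
  .where($"r" <= 2)
  .drop("r")
  .show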

Upvotes: 3

Xavier Guihot

Reputation: 61666

Given your last RDD:

val rdd =
  sc.parallelize(
    List(
      ((0, "CT"), 5),
      ((0, "MA"), 2),
      ((0, "OH"), 4),
      ((0, "NY"), 6),
      ((1, "USA"), 17)
    ))

rdd.filter(_._1._1 == 0).sortBy(-_._2).take(2).foreach(println)
> ((0,NY),6)
> ((0,CT),5)
rdd.filter(_._1._1 == 1).sortBy(-_._2).take(2).foreach(println)
> ((1,USA),17)

We first keep only the items for a given column index (.filter(_._1._1 == 0)), then sort them in decreasing order of count (.sortBy(-_._2)), and finally take at most the first 2 elements (.take(2)), which returns only 1 element if the number of records is lower than 2.
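To avoid hard-coding the column indices 0 and 1, the same filter/sort/take pattern could be applied to every index present in the RDD; a rough sketch (not part of the original answer):

// Collect the distinct column indices, then take the top 2 entries per index.
val indices = rdd.keys.map(_._1).distinct.collect()
val top2 = indices.flatMap(i => rdd.filter(_._1._1 == i).sortBy(-_._2).take(2))
top2.foreach(println)

Note that this launches one Spark job per column index, so it is only reasonable when the number of columns is small.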

Upvotes: 2
