Reputation: 358
My requirement is to get the top N items from a DataFrame.
I have this DataFrame:
import spark.implicits._ // required for toDF

val df = List(
  ("MA", "USA"),
  ("MA", "USA"),
  ("OH", "USA"),
  ("OH", "USA"),
  ("OH", "USA"),
  ("OH", "USA"),
  ("NY", "USA"),
  ("NY", "USA"),
  ("NY", "USA"),
  ("NY", "USA"),
  ("NY", "USA"),
  ("NY", "USA"),
  ("CT", "USA"),
  ("CT", "USA"),
  ("CT", "USA"),
  ("CT", "USA"),
  ("CT", "USA")).toDF("value", "country")
I was able to map it to an RDD[((Int, String), Long)], colValCount, where each entry reads as ((colIdx, value), count):
((0,CT),5)
((0,MA),2)
((0,OH),4)
((0,NY),6)
((1,USA),17)
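(For reference, here is one way such a mapping could be produced; this is a sketch, not necessarily what the asker actually did, and colValCount is the asker's name for the result:)

import org.apache.spark.rdd.RDD

// Emit a (colIdx, value) key for each cell, then count occurrences per key
val colValCount: RDD[((Int, String), Long)] = df.rdd
  .flatMap(row => Seq((0, row.getString(0)), (1, row.getString(1))))
  .map(key => (key, 1L))
  .reduceByKey(_ + _)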
Now I need to get the top 2 items for each column index. So my expected output is this:
RDD[((Int, String), Long)]
((0,CT),5)
((0,NY),6)
((1,USA),17)
I've tried the freqItems API on the DataFrame, but it's slow.
Any suggestions are welcome.
Upvotes: 4
Views: 29242
Reputation: 1300
You can map each partition using this helper function, defined in Sparkz, and then combine the results:
package sparkz.utils

import scala.reflect.ClassTag

object TopElements {
  // Single-pass top-n: keep a set of the n best (element, score) pairs plus
  // the current minimum score, so most elements are rejected with one comparison.
  def topN[T: ClassTag](elems: Iterable[T])(scoreFunc: T => Double, n: Int): List[T] =
    elems.foldLeft((Set.empty[(T, Double)], Double.MaxValue)) {
      case (accumulator @ (topElems, minScore), elem) =>
        val score = scoreFunc(elem)
        if (topElems.size < n)
          // Still collecting the first n elements: always keep this one
          (topElems + (elem -> score), math.min(minScore, score))
        else if (score > minScore) {
          // Better than the current worst: evict the minimum, recompute it
          val newTopElems = topElems - topElems.minBy(_._2) + (elem -> score)
          (newTopElems, newTopElems.map(_._2).min)
        }
        else accumulator
    }._1.toList.sortBy(_._2).reverse.map(_._1) // best first
}
Source: https://github.com/gm-spacagna/sparkz/blob/master/src/main/scala/sparkz/utils/TopN.scala
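As a rough usage sketch (my assumption, not from the Sparkz docs; for brevity it wires topN up with groupByKey instead of combining per-partition results):

import sparkz.utils.TopElements

// colValCount: RDD[((Int, String), Long)] as in the question
val top2 = colValCount
  .map { case ((colIdx, value), count) => (colIdx, (value, count)) }
  .groupByKey()
  .flatMap { case (colIdx, valueCounts) =>
    TopElements.topN(valueCounts)(_._2.toDouble, n = 2)
      .map { case (value, count) => ((colIdx, value), count) }
  }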
Upvotes: 0
Reputation: 28718
The easiest way to do this (it's a natural fit for a window function) is with SQL. Spark comes with SQL support, and SQL is an expressive tool for this problem.
Register your DataFrame as a temp view, then group and window over it.
spark.sql("""SELECT idx, value, ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) as r
FROM (
SELECT idx, value, COUNT(*) as c
FROM (SELECT 0 as idx, value FROM df UNION ALL SELECT 1, country FROM df)
GROUP BY idx, value)
HAVING r <= 2""").show()
I'd like to see whether any of the procedural / Scala approaches would let you perform the window function without an iteration or loop. I'm not aware of anything in the Spark API that supports it.
Incidentally, if you have an arbitrary number of columns you want to include, you can quite easily generate the inner section (SELECT 0 as idx, value ... UNION ALL SELECT 1, country, etc.) dynamically from the list of columns, as sketched below.
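A minimal sketch of that generation (assuming the DataFrame is registered as the temp view "df", as above):

// Build one UNION ALL branch per column, tagged with its position
val inner = df.columns.zipWithIndex
  .map { case (c, i) => s"SELECT $i AS idx, `$c` AS value FROM df" }
  .mkString(" UNION ALL ")

spark.sql(s"""
  SELECT idx, value, c
  FROM (
    SELECT idx, value, c,
           ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) AS r
    FROM (SELECT idx, value, COUNT(*) AS c FROM ($inner) GROUP BY idx, value)
  )
  WHERE r <= 2""").show()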
Upvotes: 3
Reputation: 35229
For example:
import org.apache.spark.sql.functions._

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))
  .groupBy($"index", $"value")
  .count
  .orderBy($"count".desc)
  .limit(3)
  .show
// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    1|  USA|   17|
// |    0|   NY|    6|
// |    0|   CT|    5|
// +-----+-----+-----+
where:

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))

creates a two-column DataFrame:
// +-----+-----+
// |index|value|
// +-----+-----+
// |    0|   MA|
// |    0|   MA|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    1|  USA|
// |    1|  USA|
// |    1|  USA|
// +-----+-----+
// only showing top 20 rows
If you want specifically two values for each column:
import org.apache.spark.sql.DataFrame

def topN(df: DataFrame, key: String, n: Int) = {
  df.select(
      lit(df.columns.indexOf(key)).alias("index"),
      col(key).alias("value"))
    .groupBy("index", "value")
    .count
    .orderBy($"count".desc) // descending, so we keep the most frequent values
    .limit(n)
}
topN(df, "value", 2).union(topN(df, "country", 2)).show
// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    0|   NY|    6|
// |    0|   CT|    5|
// |    1|  USA|   17|
// +-----+-----+-----+
So, as pault said, just "some combination of sort() and limit()".
Upvotes: 3
Reputation: 61666
Given your last RDD:
val rdd = sc.parallelize(
  List(
    ((0, "CT"), 5),
    ((0, "MA"), 2),
    ((0, "OH"), 4),
    ((0, "NY"), 6),
    ((1, "USA"), 17)))
rdd.filter(_._1._1 == 0).sortBy(-_._2).take(2).foreach(println)
> ((0,NY),6)
> ((0,CT),5)
rdd.filter(_._1._1 == 1).sortBy(-_._2).take(2).foreach(println)
> ((1,USA),17)
We first keep only the items for a given column index (.filter(_._1._1 == 0)). Then we sort the items in decreasing order of count (.sortBy(-_._2)). Finally, we take at most the first 2 elements (.take(2)); this yields only 1 element when the number of records for that index is lower than 2.
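To avoid hardcoding one filter per column index, a possible generalization (my sketch, not part of the answer above) is to group by the index and take the top 2 within each group:

val top2PerIdx = rdd
  .groupBy { case ((colIdx, _), _) => colIdx } // key each record by its column index
  .flatMap { case (_, items) => items.toList.sortBy(-_._2).take(2) }

top2PerIdx.collect().foreach(println)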
Upvotes: 2