Reputation: 945
What I want to do is, given a DataFrame, take the top n elements according to some specified column. top(self, num) in the RDD API is exactly what I want. I wonder if there is an equivalent API in the DataFrame world?
My first attempt was the following:
def retrieve_top_n(df, n):
    # assume we want to get the n most popular 'key' values in the DataFrame
    return df.groupBy('key').count().orderBy('count', ascending=False).limit(n).select('key')
However, I've realized that this results in non-deterministic behavior (I don't know the exact reason, but I guess limit(n) doesn't guarantee which n rows it takes).
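For illustration, here is a minimal sketch of one way to pin the ordering down, assuming ties in 'count' are the cause; breaking ties on the key itself is my own addition, not something the original code does:

from pyspark.sql import functions as F

def retrieve_top_n_deterministic(df, n):
    # break ties on 'count' by also ordering on 'key', so that
    # limit(n) always sees the same total ordering
    return (df.groupBy('key')
              .count()
              .orderBy(F.col('count').desc(), F.col('key'))
              .limit(n)
              .select('key'))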
Upvotes: 6
Views: 13214
Reputation: 28093
import numpy as np

def sample_df(num_records):
    def data():
        np.random.seed(42)
        while True:
            yield int(np.random.normal(100., 80.))

    data_iter = iter(data())
    df = sc.parallelize((
        (i, next(data_iter)) for i in range(int(num_records))
    )).toDF(('index', 'key_col'))
    return df
sample_df(1e3).show(n=5)
+-----+-------+
|index|key_col|
+-----+-------+
| 0| 139|
| 1| 88|
| 2| 151|
| 3| 221|
| 4| 81|
+-----+-------+
only showing top 5 rows
from pyspark.sql import Window
from pyspark.sql import functions

def top_df_0(df, key_col, K):
    """
    Using window functions. Handles ties OK.
    """
    window = Window.orderBy(functions.col(key_col).desc())
    return (df
            .withColumn("rank", functions.rank().over(window))
            .filter(functions.col('rank') <= K)
            .drop('rank'))

def top_df_1(df, key_col, K):
    """
    Using limit(K). Does NOT handle ties appropriately.
    """
    return df.orderBy(functions.col(key_col).desc()).limit(K)

def top_df_2(df, key_col, K):
    """
    Using limit(K) and then filtering. Handles ties OK.
    """
    num_records = df.count()
    value_at_k_rank = (df
                       .orderBy(functions.col(key_col).desc())
                       .limit(K)
                       .select(functions.min(key_col).alias('min'))
                       .first()['min'])
    return df.filter(df[key_col] >= value_at_k_rank)
The function top_df_1 is similar to the one you originally implemented. The reason it gives you non-deterministic behavior is that it cannot handle ties nicely. This may be acceptable if you have lots of data and only need an approximate answer for the sake of performance.
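To make the tie problem concrete, here is a tiny sketch (the values are made up for illustration and are not part of the benchmark below):

tie_df = sc.parallelize(list(enumerate([9, 7, 7, 7, 3]))).toDF(('index', 'key_col'))

top_df_0(tie_df, 'key_col', K=2).show()  # 4 rows: 9 plus all three tied 7s (rank 2)
top_df_1(tie_df, 'key_col', K=2).show()  # 2 rows: 9 and whichever 7 happens to come first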
For benchmarking use a Spark DF with 4 million entries and define a convenience function:
NUM_RECORDS = 4e6
test_df = sample_df(NUM_RECORDS).cache()

def show(func, df, key_col, K):
    func(df, key_col, K).select(
        functions.max(key_col),
        functions.min(key_col),
        functions.count(key_col)
    ).show()
Let's see the verdict:
%timeit show(top_df_0, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 1.62 s per loop
%timeit show(top_df_1, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 100|
+------------+------------+--------------+
1 loops, best of 3: 252 ms per loop
%timeit show(top_df_2, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 725 ms per loop
(Note that top_df_0 and top_df_2 return 108 entries in the top 100. This is due to the presence of tied entries for the 100th best value; the top_df_1 implementation simply ignores the tied entries.)

If you want an exact answer, go with top_df_2 (it is about 2x faster than top_df_0). If you want another 2x in performance and are OK with an approximate answer, go with top_df_1.
Upvotes: 8
Reputation: 11955
You should try head() instead of limit():
#sample data
df = sc.parallelize([
    ['123', 'b'], ['666', 'a'],
    ['345', 'd'], ['555', 'a'],
    ['456', 'b'], ['444', 'a'],
    ['678', 'd'], ['333', 'a'],
    ['135', 'd'], ['234', 'd'],
    ['987', 'c'], ['987', 'e']
]).toDF(('col1', 'key_col'))

#select top 'n' 'key_col' values from dataframe 'df'
def retrieve_top_n(df, key, n):
    return sqlContext.createDataFrame(
        df.groupBy(key).count().orderBy('count', ascending=False).head(n)
    ).select(key)

retrieve_top_n(df, 'key_col', 3).show()
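As a side note (not from the original answer): head(n) collects the first n rows to the driver as a plain Python list of Row objects, which is why the result is wrapped back into a DataFrame with createDataFrame:

rows = df.groupBy('key_col').count().orderBy('count', ascending=False).head(3)
print(type(rows))  # <class 'list'>
print(rows[0])     # a pyspark.sql.Row, e.g. Row(key_col=..., count=...)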
Hope this helps!
Upvotes: 2
Reputation: 5782
Options:
1) Use pyspark sql row_number() within a window function (a rough sketch is shown after this list) - relevant SO: spark dataframe grouping, sorting, and selecting top rows for a set of columns
2) Convert the ordered df to an RDD and use the top function there (hint: this doesn't appear to actually maintain ordering from my quick test, but YMMV)
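A rough sketch of option 1, reusing the "most popular key" setup from the question (names are illustrative, not a reference implementation):

from pyspark.sql import Window
from pyspark.sql import functions as F

def top_n_keys(df, key, n):
    counts = df.groupBy(key).count()
    # row_number() gives every row a distinct rank, so exactly n rows survive;
    # ties at the cutoff are broken arbitrarily unless a tie-break column is added
    w = Window.orderBy(F.col('count').desc())
    return (counts
            .withColumn('rn', F.row_number().over(w))
            .filter(F.col('rn') <= n)
            .select(key))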
Upvotes: 2