user3803714

Reputation: 5389

Spark dataframe grouping, sorting, and selecting top rows for a set of columns

I am using Spark 1.5.0. I have a Spark dataframe with following columns:

| user_id | description | fName | weight |

What I would like to do is select the top 10 and bottom 10 rows (based on the value of the weight column, which has datatype Double) per user. How do I do that using Spark SQL or DataFrame operations?

For example, for simplicity, I am only selecting the top 2 rows (based on weight) per user. I would like to sort the output by the absolute value of weight.

u1  desc1 f1  -0.20
u1  desc1 f1  +0.20
u2  desc1 f1  0.80
u2  desc1 f1  -0.60
u1  desc1 f1  1.10
u1  desc1 f1  6.40
u2  desc1 f1  0.05
u1  desc1 f1  -3.20
u2  desc1 f1  0.50
u2  desc1 f1  -0.70
u2  desc1 f1  -0.80   

Here is the desired output:

u1  desc1 f1  6.40
u1  desc1 f1  -3.20
u1  desc1 f1  1.10
u1  desc1 f1  -0.20
u2  desc1 f1  0.80
u2  desc1 f1  -0.80
u2  desc1 f1  -0.70
u2  desc1 f1  0.50

Upvotes: 4

Views: 2887

Answers (1)

zero323

Reputation: 330093

You can use window functions with row_number:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
// assumes the SQL implicits are in scope for the $ column syntax,
// e.g. import sqlContext.implicits._ (or spark.implicits._ on Spark 2.x)

// rank rows per user: ascending for the bottom weights, descending for the top
val w = Window.partitionBy($"user_id")
val rankAsc = row_number().over(w.orderBy($"weight")).alias("rank_asc")
val rankDesc = row_number().over(w.orderBy($"weight".desc)).alias("rank_desc")

// keep rows that are in the top 2 or bottom 2 by weight for their user
df.select($"*", rankAsc, rankDesc).filter($"rank_asc" <= 2 || $"rank_desc" <= 2)

In Spark 1.5.0 you can use rowNumber instead of row_number.
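
For example, a minimal 1.5.0 variant of the rank columns (using the same window w as above) could be defined as:

import org.apache.spark.sql.functions.rowNumber

// Spark 1.5.0: rowNumber() instead of row_number() (renamed in later releases)
val rankAsc = rowNumber().over(w.orderBy($"weight")).alias("rank_asc")
val rankDesc = rowNumber().over(w.orderBy($"weight".desc)).alias("rank_desc")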

Upvotes: 2
