Shaido

Reputation: 28332

Find smallest value in a rolling window partitioned by group

I have a dataframe containing different geographical positions as well as the distance to some other places. My problem is that I want to find the closest n places for each geographical position. My first idea was to use groupBy() followed by some sort of aggregation, but I couldn't get that to work.

Instead, I tried to first convert the dataframe to an RDD and then use groupByKey(). It works, but the method is cumbersome. Is there any better alternative to solve this problem? Maybe using groupBy() and aggregating somehow?

A small example of my approach, where n=2, with input:

+---+--------+
| id|distance|
+---+--------+
|  1|     5.0|
|  1|     3.0|
|  1|     7.0|
|  1|     4.0|
|  2|     1.0|
|  2|     3.0|
|  2|     3.0|
|  2|     7.0|
+---+--------+

Code:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.explode
import spark.implicits._

// Sort each id's distances, keep the two smallest, explode back to rows.
df.rdd.map{case Row(id: Long, distance: Double) => (id, distance)}
  .groupByKey()
  .map{case (id: Long, iter: Iterable[Double]) => (id, iter.toSeq.sorted.take(2))}
  .toDF("id", "distance")
  .withColumn("distance", explode($"distance"))

Output:

+---+--------+
| id|distance|
+---+--------+
|  1|     3.0|
|  1|     4.0|
|  2|     1.0|
|  2|     3.0|
+---+--------+

Upvotes: 0

Views: 1501

Answers (1)

koiralo

Reputation: 23109

You can use a Window function, as below:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

val spark = SparkSession.builder().master("local").appName("test").getOrCreate()
import spark.implicits._

case class A(id: Long, distance: Double)
val df = List(A(1, 5.0), A(1, 3.0), A(1, 7.0), A(1, 4.0), A(2, 1.0), A(2, 3.0), A(2, 3.0), A(2, 7.0))
  .toDF("id", "distance")

// Rank rows within each id by ascending distance and keep the two smallest.
val window = Window.partitionBy("id").orderBy("distance")
val result = df.withColumn("rank", row_number().over(window)).where(col("rank") <= 2)

result.drop("rank").show()

You can return more places per id by replacing the 2 with the number of results you want.
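
Note that row_number() breaks ties arbitrarily: with the input above, id 2 has two rows with distance 3.0, and only one of them makes the cut. If you want to keep every row that ties within the top n, a minimal sketch reusing df and window from above, swapping rank() in for row_number() (withTies is just an illustrative name):

import org.apache.spark.sql.functions.rank

// rank() assigns equal distances the same rank, so ties at the cutoff are
// all kept; the result may then hold more than n rows per id.
val withTies = df.withColumn("rank", rank().over(window)).where(col("rank") <= 2)
withTies.drop("rank").show()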

Hope this helps.
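
As an aside, the groupBy() plus aggregation route the question mentions also works. A minimal sketch, assuming Spark 2.4+ (for the built-in slice function) and the df from above (viaAgg is just an illustrative name):

import org.apache.spark.sql.functions.{collect_list, explode, slice, sort_array}

// Collect each id's distances into an array, sort it ascending, keep the
// first two entries, and explode back into one row per distance.
val viaAgg = df.groupBy("id")
  .agg(slice(sort_array(collect_list("distance")), 1, 2).as("distance"))
  .withColumn("distance", explode(col("distance")))

viaAgg.show()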

Upvotes: 3
