Reputation: 28332
I have a dataframe containing different geographical positions as well as the distances to some other places. My problem is that I want to find the closest n places for each geographical position. My first idea was to use groupBy()
followed by some sort of aggregation, but I couldn't get that to work.
Instead I tried to first convert the dataframe to an RDD
and then use groupByKey()
. That works, but the method is cumbersome. Is there any better alternative to solve this problem? Maybe using groupBy()
and aggregating somehow?
A small example of my approach where n=2
with input:
+---+--------+
| id|distance|
+---+--------+
| 1| 5.0|
| 1| 3.0|
| 1| 7.0|
| 1| 4.0|
| 2| 1.0|
| 2| 3.0|
| 2| 3.0|
| 2| 7.0|
+---+--------+
Code:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.explode
import spark.implicits._

df.rdd.map { case Row(id: Long, distance: Double) => (id, distance) }
  .groupByKey()                                                           // group all distances per id
  .map { case (id, distances) => (id, distances.toSeq.sorted.take(2)) }   // keep the 2 smallest
  .toDF("id", "distance")
  .withColumn("distance", explode($"distance"))                           // one row per kept distance
Output:
+---+--------+
| id|distance|
+---+--------+
| 1| 3.0|
| 1| 4.0|
| 2| 1.0|
| 2| 3.0|
+---+--------+
Upvotes: 0
Views: 1501
Reputation: 23109
You can use a window function as below:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

val spark = SparkSession.builder().master("local").appName("test").getOrCreate()
import spark.implicits._

case class A(id: Long, distance: Double)
val df = List(A(1, 5.0), A(1, 3.0), A(1, 7.0), A(1, 4.0), A(2, 1.0), A(2, 3.0), A(2, 4.0), A(2, 7.0))
  .toDF("id", "distance")

// rank rows within each id by ascending distance and keep the 2 closest
val window = Window.partitionBy("id").orderBy("distance")
val result = df.withColumn("rank", row_number().over(window)).where(col("rank") <= 2)

result.drop("rank").show()
You can increase the number of results you want by replacing the 2.
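If you would rather stay with groupBy() and an aggregation, as asked in the question, a minimal sketch using the standard Spark SQL functions collect_list, sort_array, slice and explode (slice assumes Spark 2.4+) could look like this:
import org.apache.spark.sql.functions.{col, collect_list, explode, slice, sort_array}

// collect all distances per id, sort ascending, keep the 2 smallest, then flatten back to rows
val topN = df
  .groupBy("id")
  .agg(sort_array(collect_list("distance")).as("distances"))
  .withColumn("distance", explode(slice(col("distances"), 1, 2)))
  .select("id", "distance")

topN.show()
Note that collect_list gathers every distance for an id onto a single row, so the window approach above usually scales better when some ids have many rows.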
Hope this helps.
Upvotes: 3