Reputation: 23068
Considering a list of several million objects like:
case class Point(val name:String, val x:Double, val y:Double)
I need, for a given Point target, to pick the 10 other points which are closest to the target.
val target = Point("myPoint", 34, 42)
val points = List(...) // list of several million points
def distance(p1: Point, p2: Point) = ??? // return the distance between two points
val closest10 = points.sortWith((a, b) => {
distance(a, target) < distance(b, target)
}).take(10)
This method works but is very slow. The whole list is fully sorted for every target request, whereas beyond the 10 closest points I don't care about ordering at all. I don't even need the 10 closest points to be returned in the correct order.
Ideally, I'm looking for a "return the first 10 and ignore the rest" kind of method.
The naive solution I can think of goes like this: sort by buckets of 1000, take the first bucket, sort it by buckets of 100, take the first bucket, sort it by buckets of 10, take the first bucket, done.
I guess this must be a very common problem in CS, so before rolling my own solution based on this naive approach, I'd like to know whether there is a state-of-the-art way of doing this, or whether a standard method already exists.
TL;DR how to get the first 10 items of an unsorted list, without having to sort the whole list?
Upvotes: 3
Views: 624
Reputation: 2413
Here is how it can be done with QuickSelect. I used an in-place QuickSelect. For every target point, we compute the distance from each point to the target and use QuickSelect to get the k-th smallest distance (the k-th order statistic). Whether this algorithm is faster than sorting depends on factors such as the number of points, the number of nearest points requested, and the number of targets. On my machine, for 3 million randomly generated points, 10 target points, and 10 nearest points per target, it is about 2 times faster than the sort-based approach:
Number of points: 3000000
Number of targets: 10
Number of nearest: 10
QuickSelect: 10737 ms.
Sort: 20763 ms.
Results from QuickSelect are valid
Code:
import scala.annotation.tailrec
import scala.concurrent.duration.Deadline
import scala.util.Random
case class Point(val name: String, val x: Double, val y: Double)
class NearestPoints(val points: Seq[Point]) {
private case class PointWithDistance(p: Point, d: Double) extends Ordered[PointWithDistance] {
def compare(that: PointWithDistance): Int = d.compareTo(that.d)
}
def distance(p1: Point, p2: Point): Double = {
Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
}
def get(target: Point, n: Int): Seq[Point] = {
val pd = points.map(p => PointWithDistance(p, distance(p, target))).toArray
(1 to n).map(i => quickselect(i, pd).get.p)
}
// In-place QuickSelect from https://gist.github.com/mooreniemi/9e45d55c0410cad0a9eb6d62a5b9b7ae
def quickselect[T <% Ordered[T]](k: Int, xs: Array[T]): Option[T] = {
def randint(lo: Int, hi: Int): Int =
lo + scala.util.Random.nextInt((hi - lo) + 1)
@inline
def swap[T](xs: Array[T], i: Int, j: Int): Unit = {
val t = xs(i)
xs(i) = xs(j)
xs(j) = t
}
def partition[T <% Ordered[T]](xs: Array[T], l: Int, r: Int): Int = {
var pivotIndex = randint(l, r)
val pivotValue = xs(pivotIndex)
swap(xs, r, pivotIndex)
pivotIndex = l
var i = l
while (i <= r - 1) {
if (xs(i) < pivotValue) {
swap(xs, i, pivotIndex)
pivotIndex = pivotIndex + 1
}
i = i + 1
}
swap(xs, r, pivotIndex)
pivotIndex
}
@tailrec
def quickselect0[T <% Ordered[T]](xs: Array[T], l: Int, r: Int, k: Int): T = {
if (l == r) {
xs(l)
} else {
val pivotIndex = partition(xs, l, r)
k compare pivotIndex match {
case 0 => xs(k)
case -1 => quickselect0(xs, l, pivotIndex - 1, k)
case 1 => quickselect0(xs, pivotIndex + 1, r, k)
}
}
}
xs match {
case _ if xs.isEmpty => None
case _ if k < 1 || k > xs.length => None
case _ => Some(quickselect0(xs, 0, xs.size - 1, k - 1))
}
}
}
object QuickSelectVsSort {
def main(args: Array[String]): Unit = {
val rnd = new Random(42L)
val MAX_N: Int = 3000000
val NUM_OF_NEARESTS: Int = 10
val NUM_OF_TARGETS: Int = 10
println(s"Number of points: $MAX_N")
println(s"Number of targets: $NUM_OF_TARGETS")
println(s"Number of nearest: $NUM_OF_NEARESTS")
// Generate random points
val points = (1 to MAX_N)
.map(x => Point(x.toString, rnd.nextDouble, rnd.nextDouble))
// Generate target points
val targets = (1 to NUM_OF_TARGETS).map(x => Point(s"Target$x", rnd.nextDouble, rnd.nextDouble))
var start = Deadline.now
val np = new NearestPoints(points)
val viaQuickSelect = targets.map { case target =>
val nearest = np.get(target, NUM_OF_NEARESTS)
nearest
}
var end = Deadline.now
println(s"QuickSelect: ${(end - start).toMillis} ms.")
start = Deadline.now
val viaSort = targets.map { case target =>
val closest = points.sortWith((a, b) => {
np.distance(a, target) < np.distance(b, target)
}).take(NUM_OF_NEARESTS)
closest
}
end = Deadline.now
println(s"Sort: ${(end - start).toMillis} ms.")
// Validate
assert(viaQuickSelect.length == viaSort.length)
viaSort.zipWithIndex.foreach { case (p, idx) =>
assert(p == viaQuickSelect(idx))
}
println("Results from QuickSelect are valid")
}
}
Upvotes: 1
Reputation: 10814
For finding the top n elements in a list you can Quicksort it and terminate early. That is, terminate at the point where you know there are n elements that are bigger than the pivot. See my implementation in the Rank class of Apache Jackrabbit (in Java though), which does just that.
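The early-termination idea can be sketched in Scala roughly as follows. This is a minimal illustration, not the Jackrabbit code: the names `PartialQuicksort`, `partialSort` and `smallest` are hypothetical, and the sketch selects the n smallest elements (for the n largest, pass a reversed `Ordering`). It only recurses into a partition when that partition overlaps the first n slots.

```scala
import scala.reflect.ClassTag
import scala.util.Random

object PartialQuicksort {
  // Rearranges `xs` in place so that xs(0 until n) holds the n smallest
  // elements in ascending order; the rest of the array is left unsorted.
  def partialSort[T](xs: Array[T], n: Int)(implicit ord: Ordering[T]): Unit = {
    def swap(i: Int, j: Int): Unit = { val t = xs(i); xs(i) = xs(j); xs(j) = t }

    // Lomuto partition around a random pivot; returns the pivot's final index.
    def partition(lo: Int, hi: Int): Int = {
      swap(lo + Random.nextInt(hi - lo + 1), hi)
      val pivot = xs(hi)
      var store = lo
      var i = lo
      while (i < hi) {
        if (ord.lt(xs(i), pivot)) { swap(i, store); store += 1 }
        i += 1
      }
      swap(store, hi)
      store
    }

    def loop(lo: Int, hi: Int): Unit =
      if (lo < hi) {
        val p = partition(lo, hi)
        loop(lo, p - 1)                // the left part starts inside the first n slots
        if (p + 1 < n) loop(p + 1, hi) // skip the right part once it lies beyond slot n - 1
      }

    if (n > 0) loop(0, xs.length - 1)
  }

  // Convenience wrapper: copies the input and returns the n smallest elements.
  def smallest[T: Ordering: ClassTag](items: Seq[T], n: Int): Seq[T] = {
    val arr = items.toArray
    partialSort(arr, n)
    arr.take(n).toSeq
  }
}
```

For the points in the question, one would supply an `Ordering[Point]` that compares distances to the target. Note that this sketch is not tuned for inputs with many equal keys, where the simple partition scheme degrades.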
Upvotes: 0
Reputation: 51271
I was looking at this and wondered if a PriorityQueue might be useful.
import scala.collection.mutable.PriorityQueue
case class Point(val name:String, val x:Double, val y:Double)
val target = Point("myPoint", 34, 42)
val points = List(...) //list of points
def distance(p1: Point, p2: Point) = ??? //distance between two points
//load points-priority-queue with first 10 points
val ppq = PriorityQueue(points.take(10):_*){
case (a,b) => distance(a,target) compare distance(b,target) //prioritize points
}
//step through everything after the first 10
points.drop(10).foldLeft(distance(ppq.head,target))((mxDst,nextPnt) =>
if (mxDst > distance(nextPnt,target)) {
ppq.dequeue() //drop current far point
ppq.enqueue(nextPnt) //load replacement point
distance(ppq.head,target) //return new max distance
} else mxDst)
val result: Seq[Point] = ppq.dequeueAll //10 closest points, farthest first
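As a self-contained variant of the same idea, here is a sketch that keeps a bounded max-heap of size k and caches each point's distance so it is computed only once. The names `ClosestWithHeap`, `distanceSq` and `closest` are hypothetical, and squared Euclidean distance is an assumption (the original leaves `distance` unspecified); it ranks points the same way as the true Euclidean distance.

```scala
import scala.collection.mutable

object ClosestWithHeap {
  final case class Point(name: String, x: Double, y: Double)

  // Squared Euclidean distance (assumed metric): same ranking, no sqrt needed.
  def distanceSq(a: Point, b: Point): Double = {
    val dx = a.x - b.x
    val dy = a.y - b.y
    dx * dx + dy * dy
  }

  // Returns the k points closest to `target`, nearest first.
  def closest(points: Iterable[Point], target: Point, k: Int): List[Point] = {
    // Max-heap on distance: the head is the farthest point currently kept.
    val byDistance: Ordering[(Double, Point)] =
      Ordering.fromLessThan[(Double, Point)]((a, b) => a._1 < b._1)
    val heap = mutable.PriorityQueue.empty[(Double, Point)](byDistance)

    for (p <- points) {
      val d = distanceSq(p, target)
      if (heap.size < k) heap.enqueue((d, p)) // still filling the heap
      else if (k > 0 && d < heap.head._1) {   // closer than the current farthest
        heap.dequeue()
        heap.enqueue((d, p))
      }
    }
    // Only k elements remain, so sorting them is cheap.
    heap.toList.sortWith(_._1 < _._1).map(_._2)
  }
}
```

Each point costs at most one heap update of O(log k), so the whole pass is O(N log k) with only k elements held in the heap.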
Upvotes: 1
Reputation: 22439
Below is a bare-bones method adapted from this SO answer for picking the n smallest integers from a list (which can be enhanced to handle more complex data structures):
def nSmallest(n: Int, list: List[Int]): List[Int] = {
def update(l: List[Int], e: Int): List[Int] =
if (e < l.head) (e :: l.tail).sortWith(_ > _) else l
list.drop(n).foldLeft( list.take(n).sortWith(_ > _) )( update(_, _) )
}
nSmallest( 5, List(3, 2, 8, 2, 9, 1, 5, 5, 9, 1, 7, 3, 4) )
// res1: List[Int] = List(3, 2, 2, 1, 1)
Please note that the output is in descending order.
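One possible way to extend this to other element types is to rank items by a key function. The sketch below is an assumption about what such an enhancement could look like: `NSmallestBy`, `nSmallestBy` and the `key` parameter are hypothetical names. It also replaces the full re-sort on each update with an ordered insert and returns the result in ascending order.

```scala
object NSmallestBy {
  final case class Point(name: String, x: Double, y: Double)

  // Keeps the n items with the smallest key. The working list is held in
  // descending key order, so its head is always the largest item kept so far.
  def nSmallestBy[A](n: Int, items: List[A])(key: A => Double): List[A] =
    if (n <= 0) Nil
    else {
      // Insert `e` into a list sorted in descending key order.
      def insert(sorted: List[A], e: A): List[A] = {
        val (larger, rest) = sorted.span(key(_) > key(e))
        larger ::: e :: rest
      }
      // Replace the current largest item if `e` has a smaller key.
      def update(kept: List[A], e: A): List[A] =
        if (key(e) < key(kept.head)) insert(kept.tail, e) else kept

      val (first, remaining) = items.splitAt(n)
      val seed = first.sortWith(key(_) > key(_))
      remaining.foldLeft(seed)(update).reverse // ascending order
    }
}
```

For the question's use case, the key would be the distance to the target, for example `nSmallestBy(10, points)(p => distance(p, target))`, assuming a `distance` function that returns a `Double`.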
Upvotes: 3