Jivan

Reputation: 23068

Roughly (or partially) sort a list in Scala

Consider a list of several million objects like:

case class Point(val name:String, val x:Double, val y:Double)

I need, for a given Point target, to pick the 10 other points which are closest to the target.

val target = Point("myPoint", 34, 42)
val points = List(...) // list of several million points

def distance(p1: Point, p2: Point): Double = ??? // returns the distance between two points

val closest10 = points.sortWith((a, b) => {
  distance(a, target) < distance(b, target)
}).take(10)

This method works but is very slow. The whole list is sorted for every target request, whereas beyond the 10 closest points I don't care about ordering at all. I don't even need the 10 closest points to be returned in order.

Ideally, I'm looking for a "return the first 10 and ignore the rest" kind of method.

The naive solution I can think of goes like this: sort by buckets of 1000, take the first bucket, sort it by buckets of 100, take the first bucket, sort it by buckets of 10, take the first bucket, done.

I guess this must be a very common problem in CS, so before rolling my own solution based on this naive approach, I'd like to know whether there is a state-of-the-art way of doing it, or even a standard method that already exists.

TL;DR: how do I get the first 10 items of an unsorted list without sorting the whole list?
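For reference, here is a runnable version of the full-sort baseline from the question. It assumes Euclidean distance (the question leaves `distance` unspecified), and the names `FullSortBaseline` and `closest` are made up for illustration. `sortWith` (and `sortBy`, which orders by `ord on f`) evaluates its comparison on every comparison, so `distance` runs O(N log N) times. Computing each distance once up front trims that constant, but the sort itself is still O(N log N), which is the real cost the question is about.

```scala
// Assumption: Euclidean distance; the question leaves `distance` unspecified.
case class Point(name: String, x: Double, y: Double)

object FullSortBaseline {
  def distance(p1: Point, p2: Point): Double =
    math.hypot(p2.x - p1.x, p2.y - p1.y)

  // Decorate-sort-undecorate: compute each distance exactly once, then sort
  // the (distance, point) pairs by distance. Still a full O(N log N) sort.
  def closest(points: Seq[Point], target: Point, k: Int): Seq[Point] =
    points
      .map(p => (distance(p, target), p))
      .sortWith(_._1 < _._1)
      .take(k)
      .map(_._2)
}
```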

Upvotes: 3

Views: 624

Answers (4)

Artavazd Balayan

Reputation: 2413

Here is how it can be done with QuickSelect (an in-place version). For every target point, compute the distance between each point and the target, then use QuickSelect to find the k-th smallest distance (the k-th order statistic). Whether this is faster than sorting depends on the number of points, the number of nearest points requested and the number of targets. On my machine, for 3 million randomly generated points, 10 target points and 10 nearest points per target, it is about 2 times faster than sorting:

Number of points: 3000000
Number of targets: 10
Number of nearest: 10
QuickSelect: 10737 ms.
Sort: 20763 ms.
Results from QuickSelect are valid

Code:

import scala.annotation.tailrec
import scala.concurrent.duration.Deadline
import scala.util.Random

case class Point(val name: String, val x: Double, val y: Double)

class NearestPoints(val points: Seq[Point]) {
  private case class PointWithDistance(p: Point, d: Double) extends Ordered[PointWithDistance] {
    def compare(that: PointWithDistance): Int = d.compareTo(that.d)
  }
  def distance(p1: Point, p2: Point): Double = {
    Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
  }
  def get(target: Point, n: Int): Seq[Point] = {
    val pd = points.map(p => PointWithDistance(p, distance(p, target))).toArray
    (1 to n).map(i => quickselect(i, pd).get.p)
  }
  // In-place QuickSelect from https://gist.github.com/mooreniemi/9e45d55c0410cad0a9eb6d62a5b9b7ae
  def quickselect[T <% Ordered[T]](k: Int, xs: Array[T]): Option[T] = {
    def randint(lo: Int, hi: Int): Int =
      lo + scala.util.Random.nextInt((hi - lo) + 1)

    @inline
    def swap[T](xs: Array[T], i: Int, j: Int): Unit = {
      val t = xs(i)
      xs(i) = xs(j)
      xs(j) = t
    }

    def partition[T <% Ordered[T]](xs: Array[T], l: Int, r: Int): Int = {
      var pivotIndex = randint(l, r)
      val pivotValue = xs(pivotIndex)
      swap(xs, r, pivotIndex)
      pivotIndex = l

      var i = l
      while (i <= r - 1) {
        if (xs(i) < pivotValue) {
          swap(xs, i, pivotIndex)
          pivotIndex = pivotIndex + 1
        }
        i = i + 1
      }
      swap(xs, r, pivotIndex)
      pivotIndex
    }

    @tailrec
    def quickselect0[T <% Ordered[T]](xs: Array[T], l: Int, r: Int, k: Int): T = {
      if (l == r) {
        xs(l)
      } else {
        val pivotIndex = partition(xs, l, r)
        k compare pivotIndex match {
          case 0 => xs(k)
          case -1 => quickselect0(xs, l, pivotIndex - 1, k)
          case 1 => quickselect0(xs, pivotIndex + 1, r, k)
        }
      }
    }

    xs match {
      case _ if xs.isEmpty => None
      case _ if k < 1 || k > xs.length => None
      case _ => Some(quickselect0(xs, 0, xs.size - 1, k - 1))
    }
  }
}

object QuickSelectVsSort {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(42L)
    val MAX_N: Int = 3000000
    val NUM_OF_NEARESTS: Int = 10
    val NUM_OF_TARGETS: Int = 10
    println(s"Number of points: $MAX_N")
    println(s"Number of targets: $NUM_OF_TARGETS")
    println(s"Number of nearest: $NUM_OF_NEARESTS")

    // Generate random points
    val points = (1 to MAX_N)
      .map(x => Point(x.toString, rnd.nextDouble, rnd.nextDouble))

    // Generate target points
    val targets = (1 to NUM_OF_TARGETS).map(x => Point(s"Target$x", rnd.nextDouble, rnd.nextDouble))

    var start = Deadline.now
    val np = new NearestPoints(points)
    val viaQuickSelect = targets.map { case target =>
      val nearest = np.get(target, NUM_OF_NEARESTS)
      nearest
    }
    var end = Deadline.now
    println(s"QuickSelect: ${(end - start).toMillis} ms.")

    start = Deadline.now
    val viaSort = targets.map { case target =>
      val closest = points.sortWith((a, b) => {
        np.distance(a, target) < np.distance(b, target)
      }).take(NUM_OF_NEARESTS)
      closest
    }
    end = Deadline.now
    println(s"Sort: ${(end - start).toMillis} ms.")

    // Validate
    assert(viaQuickSelect.length == viaSort.length)
    viaSort.zipWithIndex.foreach { case (p, idx) =>
      assert(p == viaQuickSelect(idx))
    }
    println("Results from QuickSelect are valid")
  }
}
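The `get` method above runs one QuickSelect per rank (1 to n). Because a partition-based select leaves every element smaller than the k-th one to its left, a single selection of the k-th smallest distance already puts the k nearest points into the first k slots, unordered, which the question says is acceptable. A minimal sketch of that variant, assuming Euclidean distance and a Lomuto partition with a random pivot; `SingleSelect`, `nearest` and `select` are hypothetical names:

```scala
import scala.util.Random

// Assumption: the same Point shape as in the question.
case class Point(name: String, x: Double, y: Double)

object SingleSelect {
  private val rnd = new Random(0L)

  // Assumption: Euclidean distance.
  def distance(p1: Point, p2: Point): Double =
    math.hypot(p2.x - p1.x, p2.y - p1.y)

  // Returns the k nearest points to `target`, in no particular order.
  def nearest(points: Seq[Point], target: Point, k: Int): Seq[Point] =
    if (k <= 0) Seq.empty
    else if (k >= points.length) points
    else select(points.toArray, target, k)

  private def select(pts: Array[Point], target: Point, k: Int): Seq[Point] = {
    val ds = pts.map(p => distance(p, target)) // each distance computed once

    // Swap two slots in both parallel arrays.
    def swap(i: Int, j: Int): Unit = {
      val tp = pts(i); pts(i) = pts(j); pts(j) = tp
      val td = ds(i); ds(i) = ds(j); ds(j) = td
    }

    // Lomuto partition of [lo, hi] around a random pivot; returns the pivot's
    // final index p. Afterwards ds(lo until p) < ds(p) <= ds(p + 1 to hi).
    def partition(lo: Int, hi: Int): Int = {
      swap(lo + rnd.nextInt(hi - lo + 1), hi)
      val pivot = ds(hi)
      var store = lo
      var i = lo
      while (i < hi) {
        if (ds(i) < pivot) { swap(i, store); store += 1 }
        i += 1
      }
      swap(store, hi)
      store
    }

    // Iterative QuickSelect: narrow [lo, hi] until slot k - 1 holds the k-th
    // smallest distance; every slot before it then holds a smaller-or-equal one.
    var lo = 0
    var hi = pts.length - 1
    var done = false
    while (!done && lo < hi) {
      val p = partition(lo, hi)
      if (p == k - 1) done = true
      else if (p < k - 1) lo = p + 1
      else hi = p - 1
    }
    pts.take(k).toSeq
  }
}
```

With a random pivot the expected cost per target is O(N), compared with O(N log N) for the full sort.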

Upvotes: 1

michid

Reputation: 10814

To find the top n elements in a list, you can Quicksort it and terminate early, that is, stop at the point where you know there are n elements bigger than the pivot. See my implementation in the Rank class of Apache Jackrabbit (in Java, though), which does just that.
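A minimal Scala sketch of the early-terminating Quicksort idea, mirrored to keep the n smallest values (which is what a nearest-points query needs). It is not a port of the Jackrabbit `Rank` class; `PartialQuicksort`, `sortPrefix` and `smallest` are illustrative names, and it works on a plain `Array[Double]` of distances for brevity:

```scala
import scala.util.Random

object PartialQuicksort {
  private val rnd = new Random(0L)

  // Rearranges xs in place so that xs(0 until n) holds the n smallest values in
  // ascending order; the remaining slots are left in unspecified order.
  def sortPrefix(xs: Array[Double], n: Int): Unit = {
    def swap(i: Int, j: Int): Unit = { val t = xs(i); xs(i) = xs(j); xs(j) = t }

    // Lomuto partition around a random pivot; returns the pivot's final index.
    def partition(lo: Int, hi: Int): Int = {
      swap(lo + rnd.nextInt(hi - lo + 1), hi)
      val pivot = xs(hi)
      var store = lo
      var i = lo
      while (i < hi) {
        if (xs(i) < pivot) { swap(i, store); store += 1 }
        i += 1
      }
      swap(store, hi)
      store
    }

    // Like Quicksort, but a right-hand partition is only sorted if it starts
    // inside the prefix we care about; everything past position n is skipped.
    def loop(lo: Int, hi: Int): Unit =
      if (lo < hi) {
        val p = partition(lo, hi)
        loop(lo, p - 1)
        if (p + 1 < n) loop(p + 1, hi)
      }

    if (n > 0) loop(0, xs.length - 1)
  }

  // Convenience wrapper: the n smallest values of `values`, ascending.
  def smallest(values: Seq[Double], n: Int): Seq[Double] = {
    val xs = values.toArray
    sortPrefix(xs, math.min(n, xs.length))
    xs.take(n).toSeq
  }
}
```

Unlike plain QuickSelect, the first n slots come out sorted; the expected cost is roughly O(N + n log n) rather than O(N log N).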

Upvotes: 0

jwvh

Reputation: 51271

I was looking at this and wondered if a PriorityQueue might be useful.

import scala.collection.mutable.PriorityQueue

case class Point(val name:String, val x:Double, val y:Double)
val target = Point("myPoint", 34, 42)
val points = List(...) //list of points

def distance(p1: Point, p2: Point): Double = ??? //distance between two points

//load points-priority-queue with first 10 points
//max-heap ordered by distance: ppq.head is the farthest of the 10
val ppq = PriorityQueue(points.take(10):_*)(
  Ordering.by((p: Point) => distance(p, target)) //prioritize points
)

//step through everything after the first 10
points.drop(10).foldLeft(distance(ppq.head,target))((mxDst,nextPnt) => 
  if (mxDst > distance(nextPnt,target)) {
    ppq.dequeue()             //drop current far point
    ppq.enqueue(nextPnt)      //load replacement point
    distance(ppq.head,target) //return new max distance
  } else mxDst)

val result: Seq[Point] = ppq.dequeueAll  //10 closest points (farthest first)
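The snippet above is not runnable as posted (`List(...)`, `???`), and its ordering calls `distance` on every heap comparison. A self-contained sketch of the same bounded max-heap idea, assuming Euclidean distance, that caches each distance next to its point; `HeapNearest` and `nearest` are hypothetical names:

```scala
import scala.collection.mutable.PriorityQueue

case class Point(name: String, x: Double, y: Double)

object HeapNearest {
  // Assumption: Euclidean distance.
  def distance(p1: Point, p2: Point): Double =
    math.hypot(p2.x - p1.x, p2.y - p1.y)

  // Max-heap order on the cached distance: the head is the farthest candidate.
  private val byDistance: Ordering[(Double, Point)] = new Ordering[(Double, Point)] {
    def compare(a: (Double, Point), b: (Double, Point)): Int =
      java.lang.Double.compare(a._1, b._1)
  }

  // Single pass, keeping at most k candidates: O(N log k) time, O(k) extra space.
  def nearest(points: Iterable[Point], target: Point, k: Int): List[Point] = {
    val heap = PriorityQueue.empty[(Double, Point)](byDistance)
    if (k > 0) points.foreach { p =>
      val d = distance(p, target) // computed once per point
      if (heap.size < k) heap.enqueue((d, p))
      else if (d < heap.head._1) {
        heap.dequeue()            // drop the current farthest candidate
        heap.enqueue((d, p))
      }
    }
    // dequeueAll yields farthest first; reverse so the nearest comes first.
    heap.dequeueAll.map(_._2).toList.reverse
  }
}
```

Because it only needs one pass and k slots of memory, this also works when the points arrive as a stream rather than an in-memory list.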

Upvotes: 1

Leo C

Reputation: 22439

Below is a bare-bones method, adapted from this SO answer, for picking the n smallest integers from a list (it can be extended to handle more complex data structures):

def nSmallest(n: Int, list: List[Int]): List[Int] = {
  def update(l: List[Int], e: Int): List[Int] =
    if (e < l.head) (e :: l.tail).sortWith(_ > _) else l

  list.drop(n).foldLeft( list.take(n).sortWith(_ > _) )( update(_, _) )
}

nSmallest( 5, List(3, 2, 8, 2, 9, 1, 5, 5, 9, 1, 7, 3, 4) )
// res1: List[Int] = List(3, 2, 2, 1, 1)

Note that the output is in reverse (descending) order.
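One way to do the extension mentioned above: a generic sketch (the names `NSmallestBy` and `nSmallestBy` are made up) that takes a key function and caches each key alongside its element, keeping the same fold-and-resort structure as `nSmallest`:

```scala
object NSmallestBy {
  // Generalizes nSmallest: keeps the n best (key, element) pairs sorted in
  // descending key order, so the head is the worst candidate kept so far.
  def nSmallestBy[A](n: Int, list: List[A])(key: A => Double): List[A] = {
    def update(acc: List[(Double, A)], e: A): List[(Double, A)] = {
      val k = key(e)
      if (acc.nonEmpty && k < acc.head._1) ((k, e) :: acc.tail).sortWith(_._1 > _._1)
      else acc
    }
    val init = list.take(n).map(e => (key(e), e)).sortWith(_._1 > _._1)
    list.drop(n).foldLeft(init)(update).map(_._2)
  }
}
```

With the question's names this would be called as `nSmallestBy(10, points)(p => distance(p, target))`. As in the original, the result is in descending key order.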

Upvotes: 3
