bsky

Reputation: 20222

Failure to add keys to map in parallel

I have the following code:

  var res: GenMap[Point, GenSeq[Point]] = points.par.groupBy(point => findClosest(point, means))
  means.par.foreach(mean => if(!res.contains(mean)) {
    println("Map doesn't contain mean: " + mean)
    res += mean -> GenSeq.empty[Point]
    println("Map contains?: " + res.contains(mean))
  })

That uses this case class:

case class Point(val x: Double, val y: Double, val z: Double)

Basically, the code groups the Point elements in points around the Point elements in means. The algorithm itself is not very important though.

My problem is that I am getting the following output:

Map doesn't contain mean: (0.44, 0.59, 0.73)
Map doesn't contain mean: (0.44, 0.59, 0.73)
Map doesn't contain mean: (0.1, 0.11, 0.11)
Map doesn't contain mean: (0.1, 0.11, 0.11)
Map contains?: true
Map contains?: true
Map contains?: false
Map contains?: true

Why would I ever get this?

Map contains?: false

I am checking if a key is in the res map. If it is not, then I'm adding it. So how can it not be present in the map?

Is there an issue with parallelism?

Upvotes: 0

Views: 184

Answers (2)

richj

Reputation: 7529

Processing a point changes means and the result is sensitive to processing order, so the algorithm doesn't lend itself to parallel execution. If parallel execution is important enough to allow a change of algorithm, then it might be possible to find an algorithm that can be applied in parallel.

Using a known set of grouping points, such as grid square centres, means that the points can be allocated to their grouping points in parallel and grouped by their grouping points in parallel:

import scala.annotation.tailrec
import scala.collection.parallel.ParMap
import scala.collection.{GenMap, GenSeq, Map}
import scala.math._
import scala.util.Random

class ParallelPoint {
  val rng = new Random(0)

  val groups: Map[Point, Point] = (for {
                i <- 0 to 100
                j <- 0 to 100
                k <- 0 to 100
              }
              yield {
                val p = Point(10.0 * i, 10.0 * j, 10.0 * k)
                p -> p
              }
    ).toMap

  val points: Array[Point] = (1 to 10000000).map(_ => Point(rng.nextDouble() * 1000.0, rng.nextDouble() * 1000.0, rng.nextDouble() * 1000.0)).toArray

  def findClosest(point: Point, groups: GenMap[Point, Point]): (Point, Point) = {
    val x: Double = rint(point.x / 10.0) * 10.0
    val y: Double = rint(point.y / 10.0) * 10.0
    val z: Double = rint(point.z / 10.0) * 10.0

    val mean: Point = groups(Point(x, y, z)) //.getOrElse(throw new Exception(s"$point out of range of mean ($x, $y, $z).") )

    (mean, point)
  }

  @tailrec
  private def total(points: GenSeq[Point]): Option[Point] = {
    points.size match {
      case 0 => None
      case 1 => Some(points(0))
      case _ => total((points(0) + points(1)) +: points.drop(2))
    }
  }

  def mean(points: GenSeq[Point]): Option[Point] = {
    total(points) match {
      case None => None
      case Some(p) => Some(p / points.size)
    }
  }

  val startTime = System.currentTimeMillis()

  println("starting test ...")

  val res: ParMap[Point, GenSeq[Point]] = points.par.map(p => findClosest(p, groups)).groupBy(pp => pp._1).map(kv => kv._1 -> kv._2.map(v => v._2))

  val groupTime = System.currentTimeMillis()
  println(s"... grouped result after ${groupTime - startTime}ms ...")

  points.par.foreach(p => if (! res(findClosest(p, groups)._1).exists(_ == p)) println(s"point $p not found"))

  val checkTime = System.currentTimeMillis()

  println(s"... checked grouped result after ${checkTime - startTime}ms ...")

  val means: ParMap[Point, GenSeq[Point]] = res.map{ kv => mean(kv._2).get -> kv._2 }

  val meansTime = System.currentTimeMillis()

  println(s"... means calculated after ${meansTime - startTime}ms.")
}

object ParallelPoint {
  def main(args: Array[String]): Unit = new ParallelPoint()
}

case class Point(x: Double, y: Double, z: Double) {
  def +(that: Point): Point = {
      Point(this.x + that.x, this.y + that.y, this.z + that.z)
  }

  def /(scale: Double): Point = Point(x/ scale, y / scale, z / scale)
}

The last step replaces the grouping point with the calculated mean of the grouped points as the map key. This processes 10 million points in about 30 seconds on my 2011 MBP.
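
As a side note, the total/mean pair above could probably be written more compactly with reduceOption. The following is only a sketch on a plain Seq (not the GenSeq used above), with Point redefined locally so that it compiles on its own; the object name MeanSketch is made up for illustration:

```scala
object MeanSketch {
  case class Point(x: Double, y: Double, z: Double) {
    def +(that: Point): Point = Point(x + that.x, y + that.y, z + that.z)
    def /(scale: Double): Point = Point(x / scale, y / scale, z / scale)
  }

  // reduceOption returns None for an empty sequence, so no explicit
  // size check is needed before dividing by the number of points.
  def mean(points: Seq[Point]): Option[Point] =
    points.reduceOption(_ + _).map(_ / points.size.toDouble)

  def main(args: Array[String]): Unit = {
    println(mean(Seq(Point(1.0, 2.0, 3.0), Point(3.0, 4.0, 5.0)))) // Some(Point(2.0,3.0,4.0))
    println(mean(Seq.empty)) // None
  }
}
```

This avoids rebuilding the sequence on every recursive step, which the drop(2) version does.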

Upvotes: 0

Mikel San Vicente

Reputation: 3863

Your code has a race condition on this line:

res += mean -> GenSeq.empty[Point]

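On a var that holds an immutable map, res += x is shorthand for res = res + x, which is a read followed by a write. More than one thread is reassigning res concurrently, so one thread can overwrite the map that another thread has just stored, and some entries are lost.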
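
To make the lost update concrete, here is a small sketch that replays one possible interleaving by hand on a single thread, so the outcome is deterministic. It uses no parallel collections; the names LostUpdateDemo, replay, seenByA and seenByB are invented for illustration, and string keys stand in for the Point keys of the question:

```scala
object LostUpdateDemo {
  // Replays, on one thread, an interleaving that two threads could produce
  // when both execute `res += key -> value` on the same var.
  def replay(): Map[String, Seq[Int]] = {
    var res = Map.empty[String, Seq[Int]]

    val seenByA = res // "thread A" reads res
    val seenByB = res // "thread B" reads res before A has written back

    res = seenByA + ("a" -> Seq.empty[Int]) // A stores a map containing "a"
    res = seenByB + ("b" -> Seq.empty[Int]) // B overwrites it using its stale snapshot

    res
  }

  def main(args: Array[String]): Unit = {
    val res = replay()
    println(res.contains("a")) // false: A's entry was overwritten
    println(res.contains("b")) // true
  }
}
```

If thread A now checks res.contains("a"), it sees false even though it has just added the key, which matches the "Map contains?: false" line in the question's output.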

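This code solves the problem by adding the missing keys sequentially with a fold, so no variable is reassigned from several threads: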

val closest = points.par.groupBy(point => findClosest(point, means))
val res = means.foldLeft(closest) {
  case (map, mean) =>
    if(map.contains(mean))
      map
    else
      map + (mean -> GenSeq.empty[Point])
}
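
For anyone who wants to try the fold in isolation, here is a self-contained sketch. It is not the question's actual code: findClosest is a hypothetical stand-in that uses squared Euclidean distance, plain sequential collections replace .par so that no extra module is needed, and groupWithEmptyMeans is an invented name:

```scala
object FoldFixDemo {
  case class Point(x: Double, y: Double, z: Double)

  private def sq(d: Double): Double = d * d

  // Hypothetical stand-in for the question's findClosest:
  // picks the mean with the smallest squared Euclidean distance.
  def findClosest(p: Point, means: Seq[Point]): Point =
    means.minBy(m => sq(p.x - m.x) + sq(p.y - m.y) + sq(p.z - m.z))

  // Groups the points by closest mean, then adds an empty group for every
  // mean that attracted no points. The fold threads the map through as an
  // accumulator instead of reassigning a shared var.
  def groupWithEmptyMeans(points: Seq[Point], means: Seq[Point]): Map[Point, Seq[Point]] = {
    val closest = points.groupBy(p => findClosest(p, means))
    means.foldLeft(closest) { (map, mean) =>
      if (map.contains(mean)) map
      else map + (mean -> Seq.empty[Point])
    }
  }

  def main(args: Array[String]): Unit = {
    val means = Seq(Point(0.0, 0.0, 0.0), Point(1.0, 1.0, 1.0), Point(9.0, 9.0, 9.0))
    val points = Seq(Point(0.1, 0.0, 0.0), Point(0.9, 1.0, 1.0))
    val res = groupWithEmptyMeans(points, means)
    println(res.keySet == means.toSet)       // true: every mean is a key
    println(res(Point(9.0, 9.0, 9.0)).isEmpty) // true: no point is closest to this mean
  }
}
```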

Upvotes: 2
