Tony K.

Reputation: 5605

Efficient way to fold list in scala, while avoiding allocations and vars

I have a bunch of items in a list, and I need to analyze the content to find out how many of them are "complete". I started out with partition, but then realized that I didn't need two lists back, so I switched to a fold:

val counts = groupRows.foldLeft( (0,0) )( (pair, row) => 
     if(row.time == 0) (pair._1+1,pair._2) 
     else (pair._1, pair._2+1)
   )

However, I have a lot of rows to go through for a lot of parallel users, and this is causing a lot of GC activity. (That is an assumption on my part: the GC could be from other things, but I suspect the fold, since I understand it will allocate a new tuple for every item folded.)

For the time being, I've rewritten this as

var complete = 0
var incomplete = 0
groupRows.foreach(row => if(row.time != 0) complete += 1 else incomplete += 1)

which fixes the GC, but introduces vars.

I was wondering if there was a way of doing this without using vars while also not abusing the GC?

EDIT:

Hard call on the answers I've received. A var implementation seems to be considerably faster on large lists (by roughly 40%) than even an optimized tail-recursive version that is more functional but should be equivalent.

The first answer from dhg seems to be on par with the performance of the tail-recursive one, implying that the size pass is super-efficient. In fact, when optimized, it runs very slightly faster than the tail-recursive one on my hardware.

Upvotes: 4

Views: 958

Answers (7)

Tony K.

Reputation: 5605

OK, inspired by the answers above, but really wanting to only pass over the list once and avoid GC, I decided that, in the face of a lack of direct API support, I would add this to my central library code:

class RichList[T](private val theList: List[T]) {
  def partitionCount(f: T => Boolean): (Int, Int) = {
    var matched = 0
    var unmatched = 0
    theList.foreach(r => { if (f(r)) matched += 1 else unmatched += 1 })
    (matched, unmatched)
  }
}

object RichList {
  implicit def apply[T](list: List[T]): RichList[T] = new RichList(list)
}

Then in my application code (if I've imported the implicit), I can write var-free expressions:

val (complete, incomplete) = groupRows.partitionCount(_.time != 0)

and get what I want: an optimized GC-friendly routine that prevents me from polluting the rest of the program with vars.
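
As a variation on the wrapper above, here is a minimal sketch using an implicit value class (assuming Scala 2.11 or later). This is a swapped-in technique, not the code I benchmarked: extending AnyVal lets the compiler usually skip allocating the wrapper object, and the vars still stay inside one method. The names PartitionCountSketch, RichListOps, Row, and Demo are hypothetical and only there for illustration.

```scala
object PartitionCountSketch {
  // Hypothetical stand-in for the real row type.
  final case class Row(time: Long)

  // Value class: the wrapper is normally not allocated at the call site.
  implicit class RichListOps[T](private val underlying: List[T]) extends AnyVal {
    // Counts (matched, unmatched) in a single pass over the list.
    def partitionCount(f: T => Boolean): (Int, Int) = {
      var matched = 0
      var unmatched = 0
      var rest = underlying
      while (rest.nonEmpty) {
        if (f(rest.head)) matched += 1 else unmatched += 1
        rest = rest.tail
      }
      (matched, unmatched)
    }
  }

  object Demo {
    // The implicit class is already in scope inside PartitionCountSketch.
    val groupRows = List(Row(0), Row(5), Row(7))
    val result: (Int, Int) = groupRows.partitionCount(_.time != 0)
  }
}
```

Outside the enclosing object, the call site needs `import PartitionCountSketch._` to bring the implicit class into scope. Only the final result tuple is allocated here, once per call rather than once per element.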

However, I then saw Luigi's benchmark, and updated it to:

  • Use a longer list so that multiple passes on the list were more obvious in the numbers
  • Use a boolean function in all cases, so that we are comparing things fairly

http://pastebin.com/2XmrnrrB

The var implementation is definitely considerably faster, even though Luigi's routine should be identical (as one would expect with optimized tail recursion). Surprisingly, dhg's dual-pass original is just as fast (slightly faster if compiler optimization is on) as the tail-recursive one. I do not understand why.

Upvotes: 2

som-snytt

Reputation: 39577

How about this one? No import tax.

import scala.collection.generic.CanBuildFrom
import scala.collection.Traversable
import scala.collection.mutable.Builder

case class Count(n: Int, total: Int) {
  def not = total - n
}
object Count {
  implicit def cbf[A]: CanBuildFrom[Traversable[A], Boolean, Count] = new CanBuildFrom[Traversable[A], Boolean, Count] {
    def apply(): Builder[Boolean, Count] = new Counter
    def apply(from: Traversable[A]): Builder[Boolean, Count] = apply()
  }
}
class Counter extends Builder[Boolean, Count] {
  var n = 0
  var ttl = 0
  override def +=(b: Boolean) = { if (b) n += 1; ttl += 1; this }
  override def clear(): Unit = { n = 0 ; ttl = 0 }
  override def result = Count(n, ttl)
}

object Counting extends App {
  val vs = List(4, 17, 12, 21, 9, 24, 11)
  val res: Count = vs map (_ % 2 == 0)
  Console println s"${vs} have ${res.n} evens out of ${res.total}; ${res.not} were odd."
  val res2: Count = vs collect { case i if i % 2 == 0 => i > 10 }
  Console println s"${vs} have ${res2.n} evens over 10 out of ${res2.total}; ${res2.not} were smaller."
}

Upvotes: 2

dhg

Reputation: 52681

The cleanest two-pass solution is probably to just use the built-in count method:

val complete = groupRows.count(_.time == 0)
val counts = (complete, groupRows.size - complete)

But you can do it in one pass if you use partition on an iterator:

val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
val counts = (complete.size, incomplete.size)

This works because the two returned iterators are linked behind the scenes: calling next on one moves the original iterator forward until it finds a matching element, while the non-matching elements it skips are remembered for the other iterator so that they don't need to be recomputed.


Example of the one-pass solution:

scala> val groupRows = List(Row(0), Row(1), Row(1), Row(0), Row(0)).view.map{x => println(x); x}
scala> val (complete, incomplete) = groupRows.iterator.partition(_.time == 0)
Row(0)
Row(1)
complete: Iterator[Row] = non-empty iterator
incomplete: Iterator[Row] = non-empty iterator
scala> val counts = (complete.size, incomplete.size)
Row(1)
Row(0)
Row(0)
counts: (Int, Int) = (3,2)

Upvotes: 11

Luigi Plinge

Reputation: 51109

I see you've already accepted an answer, but you rightly mention that that solution will traverse the list twice. The way to do it efficiently is with recursion.

def counts(xs: List[...], complete: Int = 0, incomplete: Int = 0): (Int,Int) = 
  xs match {
    case Nil => (complete, incomplete)
    case row :: tail => 
      if (row.time == 0) counts(tail, complete + 1, incomplete)
      else               counts(tail, complete, incomplete + 1)
  }

This is effectively just a customized fold, except we use 2 accumulators which are just Ints (primitives) instead of tuples (reference types). It should also be just as efficient as a while-loop with vars. In fact, the bytecode should be identical.
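
For reference, a self-contained sketch of the same recursion that compiles as written. Row is a hypothetical stand-in for the elided element type, and the @tailrec annotation is an addition: it makes the compiler reject the method if the recursive calls are ever not in tail position, so the loop optimization cannot be lost silently.

```scala
import scala.annotation.tailrec

object TailRecCounts {
  // Hypothetical stand-in for the question's row type.
  final case class Row(time: Long)

  // Both accumulators are plain Ints, so nothing is allocated per element;
  // only the final result tuple is created.
  @tailrec
  def counts(xs: List[Row], complete: Int = 0, incomplete: Int = 0): (Int, Int) =
    xs match {
      case Nil => (complete, incomplete)
      case row :: tail =>
        if (row.time == 0) counts(tail, complete + 1, incomplete)
        else counts(tail, complete, incomplete + 1)
    }
}
```

Calling `TailRecCounts.counts(rows)` relies on the default arguments to start both counters at zero.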

Upvotes: 3

Dave Griffith

Reputation: 20515

Maybe it's just me, but I prefer using the various specialized folds (.size, .exists, .sum, .product) if they are available. I find it clearer and less error-prone than the heavy-duty power of general folds.

val complete = groupRows.view.filter(_.time==0).size
(complete, groupRows.length - complete)

Upvotes: 2

Dominic Bou-Samra

Reputation: 15416

You could just calculate it using the difference like so:

def counts(groupRows: List[Row]) = {
  val complete = groupRows.foldLeft(0){ (count, row) => 
    if(row.time == 0) count + 1 else count
  }
  (complete, groupRows.length - complete)
}

Upvotes: 0

Rex Kerr

Reputation: 167891

It is slightly tidier to use a mutable accumulator pattern, like so, especially if you can re-use your accumulator:

case class Accum(var complete: Int = 0, var incomplete: Int = 0) {
  def inc(compl: Boolean): this.type = {
    if (compl) complete += 1 else incomplete += 1
    this
  }
}
val counts = groupRows.foldLeft( Accum() ){ (a, row) => a.inc( row.time == 0 ) }

If you really want to, you can hide your vars as private; if not, you still are a lot more self-contained than the pattern with vars.
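
A minimal sketch of the private-var variant, under the assumption that a plain class with read-only accessors is acceptable in place of the case class. AccumSketch, Row, and the field names completeCount and incompleteCount are hypothetical.

```scala
object AccumSketch {
  // Hypothetical stand-in for the question's row type.
  final case class Row(time: Long)

  final class Accum {
    // The mutable state is invisible outside this class.
    private var completeCount = 0
    private var incompleteCount = 0

    // Mutates in place and returns the same instance, so the fold
    // reuses one accumulator instead of allocating one per element.
    def inc(compl: Boolean): this.type = {
      if (compl) completeCount += 1 else incompleteCount += 1
      this
    }

    def complete: Int = completeCount
    def incomplete: Int = incompleteCount
  }

  def counts(groupRows: List[Row]): (Int, Int) = {
    val acc = groupRows.foldLeft(new Accum) { (a, row) => a.inc(row.time == 0) }
    (acc.complete, acc.incomplete)
  }
}
```

Callers only ever see the two Int accessors, so the mutation does not leak past the fold.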

Upvotes: 0
