Reputation: 65
Just playing with continuations. The goal is to create a function that takes another function and a repetition count as parameters, and returns a function that applies the given function that many times.
The implementation looks pretty straightforward:
import scala.annotation.tailrec

def n_times[T](func: T => T, count: Int): T => T = {
  @tailrec
  def n_times_cont(cnt: Int, continuation: T => T): T => T = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    case _ => n_times_cont(cnt - 1, i => continuation(func(i)))
  }
  n_times_cont(count, func)
}

def inc(x: Int) = x + 1
val res1 = n_times(inc, 1000)(1)     // Works OK, returns 1001
val res = n_times(inc, 10000000)(1)  // FAILS
But there is a problem: this code fails with a StackOverflowError. Why is there no tail-call optimization here?
I'm running it in Eclipse using the Scala plugin, and it fails with:
Exception in thread "main" java.lang.StackOverflowError
    at scala.runtime.BoxesRunTime.boxToInteger(Unknown Source)
    at Task_Mult$$anonfun$1.apply(Task_Mult.scala:25)
    at Task_Mult$$anonfun$n_times_cont$1$1.apply(Task_Mult.scala:18)
P.S.
F# code, which is almost a direct translation, works without any issues:
let n_times_cnt func count =
    let rec n_times_impl count' continuation =
        match count' with
        | _ when count' < 1 -> failwith "wrong count"
        | 1 -> continuation
        | _ -> n_times_impl (count' - 1) (func >> continuation)
    n_times_impl count func

let inc x = x + 1
let res = (n_times_cnt inc 10000000) 1
printfn "%d" res
Upvotes: 2
Views: 1546
Reputation: 41646
The Scala standard library has an implementation of trampolines in scala.util.control.TailCalls. So, revisiting your implementation: when you build up the nested calls with continuation(func(t)), those are tail calls, just not optimized by the compiler. So let's build up a T => TailRec[T], where the stack frames will be replaced with objects on the heap. Then return a function that takes the argument and passes it to that trampolined function:
import scala.util.control.TailCalls._

def n_times_trampolined[T](func: T => T, count: Int): T => T = {
  @annotation.tailrec
  def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    // tailcall defers the nested call instead of making it on the stack
    case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
  }
  // The innermost step applies func once and finishes
  val lifted: T => TailRec[T] = t => done(func(t))
  // .result runs the trampoline loop until a done value is reached
  t => n_times_cont(count, lifted)(t).result
}
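To make the "stack frames replaced with heap objects" idea more concrete, here is a minimal hand-rolled trampoline with the same shape as the function above. Bounce, Done, More, run, nTimesBounced and TrampolineSketch are hypothetical names invented for this sketch; scala.util.control.TailCalls follows the same basic idea, but this is not its actual implementation.

```scala
import scala.annotation.tailrec

// Hypothetical minimal trampoline, for illustration only (not the TailCalls source).
object TrampolineSketch {
  // A computation has either finished (Done) or has one more step to run (More).
  sealed trait Bounce[A]
  final case class Done[A](value: A) extends Bounce[A]
  final case class More[A](next: () => Bounce[A]) extends Bounce[A]

  // Driver loop: every step returns here instead of calling the next step
  // directly, so the stack stays flat while pending steps live on the heap.
  @tailrec
  def run[A](b: Bounce[A]): A = b match {
    case Done(value) => value
    case More(next)  => run(next())
  }

  // Same shape as n_times_trampolined, with Bounce in place of TailRec.
  def nTimesBounced[T](func: T => T, count: Int): T => T = {
    require(count >= 1, s"count was wrong $count")
    @tailrec
    def loop(cnt: Int, continuation: T => Bounce[T]): T => Bounce[T] =
      if (cnt == 1) continuation
      // Defer the call to the previous continuation inside a More thunk.
      else loop(cnt - 1, t => More(() => continuation(func(t))))
    val lifted: T => Bounce[T] = t => Done(func(t))
    // Build the chain of closures once, then trampoline each application.
    val chain = loop(count, lifted)
    t => run(chain(t))
  }
}

import TrampolineSketch._

val inc: Int => Int = x => x + 1
println(nTimesBounced(inc, 1000000)(1)) // 1000001
```

Each application of the chain returns a More almost immediately, and run unwinds them one at a time, so a million levels cost heap allocations rather than stack depth.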
Upvotes: 5
Reputation: 7373
I could be wrong here, but I suspect that the n_times_cont inner function is properly converted to use tail recursion; the culprit's not there.
The stack is blown up by the collected continuation closures (i.e. the i => continuation(func(i))), which make 10000000 nested calls to your inc method once you apply the result of the main function.
In fact, you can try:
scala> val rs = n_times(inc, 1000000)
rs: Int => Int = <function1> //<- we're happy here
scala> rs(1) //<- this blows up the stack!
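To see where those nested calls come from, here is a hand-unrolled sketch of roughly what n_times(inc, 3) builds (k1, k2 and k3 are made-up names for illustration). Each closure calls the previous one; that call is in tail position, but scalac's @tailrec only rewrites a method's calls to itself, and the JVM does not eliminate other tail calls, so every level adds stack depth when the final function is applied.

```scala
val inc: Int => Int = x => x + 1

// Hand-unrolled equivalent of n_times(inc, 3): each step wraps the previous
// continuation in a new closure, i => continuation(func(i)).
val k1: Int => Int = inc               // n_times_cont starts with func
val k2: Int => Int = i => k1(inc(i))   // one more level
val k3: Int => Int = i => k2(inc(i))   // what n_times(inc, 3) would return

// k3(1) calls k2, which calls k1: at least one stack frame per level,
// so a count of 10000000 means 10000000 nested calls.
println(k3(1)) // 4
```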
As an aside, for the sake of readability you can rewrite i => continuation(func(i)) as continuation compose func.
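A sketch of the question's n_times with that rewrite applied (nTimesComposed is a hypothetical name; the require and if/else stand in for the original match and keep its count check). Since f compose g is just x => f(g(x)), this reads better but still builds the same chain of nested closures, so it will still overflow the stack for very large counts.

```scala
import scala.annotation.tailrec

def nTimesComposed[T](func: T => T, count: Int): T => T = {
  // Same precondition as the original (require throws IllegalArgumentException).
  require(count >= 1, s"count was wrong $count")
  @tailrec
  def loop(cnt: Int, continuation: T => T): T => T =
    if (cnt == 1) continuation
    // continuation compose func is equivalent to i => continuation(func(i))
    else loop(cnt - 1, continuation compose func)
  loop(count, func)
}

val inc: Int => Int = x => x + 1
println(nTimesComposed(inc, 1000)(1)) // 1001
```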
Upvotes: 1