I am trying to implement tail-call optimization for traversing a tree-like structure using continuation-passing style in Scala. Unfortunately, my previous experience with F# does not help much. Here is my recursive version without tail-call optimization:
```scala
def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  def traverseTreeRec(tree: Tree, continuation: () => Unit, action: Int => Unit): Unit = tree match {
    case Leaf(n) => {
      action(n)
      continuation()
    }
    case Node(n1, n2) => {
      traverseTreeRec(n1, () => traverseTreeRec(n2, continuation, action), action)
    }
  }
  traverseTreeRec(tree, () => (), action)
}
```
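For reference, here is a self-contained sketch of the same non-trampolined CPS traversal. The `Tree`, `Leaf` and `Node` definitions are an assumption (the question does not show them), and `action` is captured from the enclosing method instead of being passed to every call. Each `continuation()` call in the `Leaf` case is an ordinary JVM call that only returns after the rest of the traversal has finished, so the stack grows with the number of leaves visited, which is why a large tree can overflow it.

```scala
// Assumption: Tree is not shown in the question; this definition simply
// matches the Leaf(n) / Node(n1, n2) patterns used there.
sealed trait Tree
final case class Leaf(n: Int) extends Tree
final case class Node(left: Tree, right: Tree) extends Tree

// The question's CPS traversal without trampolining. `action` is captured
// from the enclosing method instead of being passed to every call.
def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  def traverseTreeRec(tree: Tree, continuation: () => Unit): Unit = tree match {
    case Leaf(n) =>
      action(n)
      // Ordinary call: this frame stays on the stack until everything
      // reached through the continuation chain has finished.
      continuation()
    case Node(n1, n2) =>
      // Visit the left subtree now; defer the right subtree to the continuation.
      traverseTreeRec(n1, () => traverseTreeRec(n2, continuation))
  }
  traverseTreeRec(tree, () => ())
}

val visited = scala.collection.mutable.ListBuffer.empty[Int]
traverseTree(Node(Node(Leaf(1), Leaf(2)), Leaf(3)))(n => visited += n)
println(visited.toList) // List(1, 2, 3)
```

Leaves are visited left to right because the right subtree is always deferred into the continuation.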
Afterwards I tried to rewrite it using `@annotation.tailrec` and `TailCalls`, but I'm still not sure how to wrap the continuation:
```scala
def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  @annotation.tailrec
  def traverseTreeRec(tree: Tree, continuation: () => TailRec[Unit], action: Int => Unit): TailRec[Unit] = tree match {
    case Leaf(n) => {
      action(n)
      continuation()
    }
    case Node(n1, n2) =>
      // how to properly implement tail call here?
      // ERROR: it contains a recursive call not in tail position
      traverseTreeRec(n1, () => tailcall(traverseTreeRec(n2, continuation, action)), action)
  }
  traverseTreeRec(tree, () => done(), action)
}
```
Thanks in advance
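The basic `scala.util.control.TailCalls` pattern is: return `done(value)` where the recursion stops, wrap each call in tail position in `tailcall(...)`, and run the resulting `TailRec` with `.result`. A minimal sketch using mutually recursive `isEven`/`isOdd` (hypothetical helpers, the usual illustration for this API), a case where `@annotation.tailrec` cannot help because neither function calls itself:

```scala
import scala.util.control.TailCalls._

// Hypothetical helpers: mutual recursion, so neither function calls itself
// and @annotation.tailrec cannot turn them into loops.
def isEven(n: Int): TailRec[Boolean] =
  if (n == 0) done(true)      // recursion ends: wrap the final value
  else tailcall(isOdd(n - 1)) // defer the call instead of making it now

def isOdd(n: Int): TailRec[Boolean] =
  if (n == 0) done(false)
  else tailcall(isEven(n - 1))

// .result runs the deferred calls one at a time in a loop (a trampoline),
// so a million steps do not need a million stack frames.
println(isEven(1000000).result) // true
println(isOdd(1000000).result)  // false
```

`@annotation.tailrec` only covers direct self-recursion that the compiler can rewrite as a loop. A call made from inside a closure, such as a continuation, is not in tail position of the enclosing method, which is what the error in the question reports.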
Finally, I got an answer on the Coursera discussion forum:
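```scala
import scala.util.control.TailCalls._

def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  def traverseTreeRec(tree: Tree, continuation: () => TailRec[Unit]): TailRec[Unit] = tree match {
    case Leaf(n) => {
      action(n)
      continuation()
    }
    case Node(n1, n2) =>
      // Return a suspended call instead of recursing on the JVM stack.
      tailcall(traverseTreeRec(n1,
        () => traverseTreeRec(n2,
          () => tailcall(continuation()))))
  }
  // .result runs the trampoline loop until it reaches done(()).
  traverseTreeRec(tree, () => done(())).result
}
```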
P.S. The question suggested by @rob-napier explains in more detail why `tailcall` has to be applied this way.
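A self-contained sketch of the forum answer, run on a left-leaning tree with 100,000 leaves built by a fold (the `Tree` definitions are again an assumption). A `Node` returns a suspended `tailcall` at once, and a `Leaf` runs `action` and goes at most a few JVM calls deeper before hitting a `tailcall` or `done`, so the stack stays shallow and the `.result` loop does the iterating. The non-trampolined version from the question would likely throw `StackOverflowError` on a tree this size with default JVM settings.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.control.TailCalls._

// Assumed Tree definition (the question does not show it).
sealed trait Tree
final case class Leaf(n: Int) extends Tree
final case class Node(left: Tree, right: Tree) extends Tree

// The forum answer, unchanged apart from layout and comments.
def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  def traverseTreeRec(tree: Tree, continuation: () => TailRec[Unit]): TailRec[Unit] = tree match {
    case Leaf(n) =>
      action(n)
      continuation()
    case Node(n1, n2) =>
      // Return a suspended call; the .result loop below will run it.
      tailcall(traverseTreeRec(n1,
        () => traverseTreeRec(n2,
          () => tailcall(continuation()))))
  }
  traverseTreeRec(tree, () => done(())).result
}

// Visiting order on a small tree: left to right.
val order = ListBuffer.empty[Int]
traverseTree(Node(Node(Leaf(1), Leaf(2)), Leaf(3)))(n => order += n)
println(order.toList) // List(1, 2, 3)

// A left-leaning tree with 100000 leaves, built iteratively with a fold.
// (Avoid printing or comparing it: the case-class toString/equals recurse.)
val leaves = 100000
val deep: Tree = (2 to leaves).foldLeft(Leaf(1): Tree)((acc, i) => Node(acc, Leaf(i)))

var count = 0
var sum = 0L
traverseTree(deep) { n => count += 1; sum += n }
println(count) // 100000
println(sum)   // 5000050000
```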
I know little Scala, but I guess in your code you could do:
```scala
tailcall(traverseTreeRec(n1, () => tailcall(traverseTreeRec(n2, continuation, action)), action))
```
The `tailcall` you had is actually inside another anonymous function, which only gets called when a leaf is reached. Perhaps you need it in all tail positions:
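```scala
import scala.util.control.TailCalls._

def traverseTree(tree: Tree)(action: Int => Unit): Unit = {
  // No @annotation.tailrec: the recursive calls sit inside tailcall(...) thunks,
  // so the compiler would still reject it; TailCalls does the work instead.
  def traverseTreeRec(tree: Tree, continuation: () => TailRec[Unit], action: Int => Unit): TailRec[Unit] = tree match {
    case Leaf(n) =>
      action(n)
      tailcall(continuation())
    case Node(n1, n2) =>
      tailcall(traverseTreeRec(n1, () => tailcall(traverseTreeRec(n2, continuation, action)), action))
  }
  // .result runs the trampoline; without it the traversal would never execute.
  traverseTreeRec(tree, () => done(()), action).result
}
```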
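One detail matters in the last line of `traverseTree`: a `TailRec` value is only a description of the work, and nothing runs until `.result` drives the trampoline. Because the method returns `Unit`, a `TailRec` that is produced but never run is silently discarded, and the traversal never happens. A minimal sketch of this behaviour with a hypothetical `countDown` function:

```scala
import scala.util.control.TailCalls._

var steps = 0

// Hypothetical example: counts down through the trampoline, recording each step.
def countDown(n: Int): TailRec[Unit] = {
  steps += 1
  if (n == 0) done(()) else tailcall(countDown(n - 1))
}

// tailcall only wraps the call in a thunk: nothing has run yet.
val program: TailRec[Unit] = tailcall(countDown(3))
println(steps) // 0

// .result runs countDown(3), countDown(2), countDown(1) and countDown(0).
program.result
println(steps) // 4
```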
In Scheme it would have been faster to make the recursion into the first branch an ordinary (non-tail) call and only the second one a tail call, but I guess you can't mix the two in Scala, because it needs trampolining to get tail-call optimization here.
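For plain self-recursion without `TailCalls`, a version of this mixed strategy can be sketched in Scala: scalac turns a self-recursive call in tail position of a local (or final/private) method into a jump even without `@annotation.tailrec`; the annotation only makes the compiler report an error when it cannot. Recursing normally into the left branch and tail-calling into the right one therefore uses stack proportional to the largest number of left branches on any root-to-leaf path, not to the number of leaves. It is still not stack-safe for left-leaning trees. A sketch, assuming the same hypothetical `Tree` type:

```scala
// Assumed Tree definition (the question does not show it).
sealed trait Tree
final case class Leaf(n: Int) extends Tree
final case class Node(left: Tree, right: Tree) extends Tree

// Plain recursion, no TailCalls: only the right branch is a tail call.
def traverseHybrid(tree: Tree)(action: Int => Unit): Unit = {
  // No @annotation.tailrec: it would reject the non-tail call on the left.
  def go(t: Tree): Unit = t match {
    case Leaf(n) => action(n)
    case Node(l, r) =>
      go(l) // ordinary call: one stack frame per left branch on the current path
      go(r) // self-recursive tail call: scalac compiles it into a jump
  }
  go(tree)
}

// Right-leaning tree Node(Leaf(1), Node(Leaf(2), ... Leaf(100000))),
// built iteratively; traversing it needs only a constant number of frames.
val leaves = 100000
val rightLeaning: Tree =
  (leaves - 1 to 1 by -1).foldLeft(Leaf(leaves): Tree)((acc, i) => Node(Leaf(i), acc))

var count = 0
var last = 0
traverseHybrid(rightLeaning) { n => count += 1; last = n }
println(count) // 100000
println(last)  // 100000
```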