Reputation: 1064
I am trying to understand how various recursive functions can be converted to tail-recursive form. I've looked through the many examples of both Fibonacci and factorial conversions to tail recursion and understand those, but I am having a tough time making the leap to a problem with a somewhat different structure. An example is:
def countSteps(n: Int): Int = {
  if (n < 0) return 0
  if (n == 0) return 1
  countSteps(n - 1) + countSteps(n - 2) + countSteps(n - 3)
}
How would you convert this to a tail recursive implementation?
I've looked through similar questions, such as "Convert normal recursion to tail recursion", but again these don't seem to translate to this problem.
Upvotes: 7
Views: 2777
Reputation: 14404
The trick to turning your function, or any function that makes multiple calls to itself, into a tail-recursive function is to break down the multiple calls into a list of arguments that can be recursively consumed and applied:
def countSteps(n: Int): Int = {
  def countSteps2(steps: List[Int], acc: Int): Int =
    steps match {
      case Nil => acc
      case head :: tail =>
        val (n, s) = if (head < 0) (0, Nil)
                     else if (head == 0) (1, Nil)
                     else (0, List(head - 1, head - 2, head - 3))
        countSteps2(s ::: tail, n + acc)
    }
  countSteps2(List(n), 0)
}
The inner function countSteps2 no longer takes a single argument but a list of pending arguments and an accumulator. For the argument at the head of the list, we either compute its value directly or generate a new list of arguments that can be added to the existing sequence for countSteps2 to recurse on.
Each time there is an input, we take the head and compute either a 0, a 1, or an additional list of arguments. We then recurse on countSteps2 with the new list of arguments prepended to the existing tail and any value we calculated added to the accumulator, acc. Since the only call to countSteps2 is the last expression evaluated in its branch (nothing remains to be done after it returns), the recursion is a tail recursion.
We finally exit when all inputs have been consumed, at which point acc holds the sum of the results of all the recursive steps.
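One way to check the tail-position claim is to annotate the helper with scala.annotation.tailrec, which makes the compiler reject the definition if the call is not a tail call. The following is only a sketch of that idea: the names ListVersionCheck, countStepsNaive, countStepsList, loop and pending are made up for illustration, and the helper is restructured slightly so that each branch recurses directly.

```scala
import scala.annotation.tailrec

object ListVersionCheck {
  // The original, non-tail-recursive definition from the question.
  def countStepsNaive(n: Int): Int =
    if (n < 0) 0
    else if (n == 0) 1
    else countStepsNaive(n - 1) + countStepsNaive(n - 2) + countStepsNaive(n - 3)

  // List-based version: `pending` holds the arguments still to be evaluated.
  def countStepsList(n: Int): Int = {
    @tailrec // compilation fails if any recursive call is not in tail position
    def loop(pending: List[Int], acc: Int): Int = pending match {
      case Nil => acc
      case head :: tail =>
        if (head < 0) loop(tail, acc)            // contributes 0
        else if (head == 0) loop(tail, acc + 1)  // contributes 1
        else loop((head - 1) :: (head - 2) :: (head - 3) :: tail, acc)
    }
    loop(List(n), 0)
  }

  def main(args: Array[String]): Unit =
    // Both definitions should agree on small inputs.
    (0 to 10).foreach(n => assert(countStepsNaive(n) == countStepsList(n)))
}
```

Note that the list plays the role of the call stack, only kept on the heap, and the function still visits every leaf of the original call tree, so the running time stays exponential in n.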
Upvotes: 3
Reputation: 51501
Some things just are not tail recursive, and any attempt to transform them will end up building the stack manually.
In this case, however, we can accumulate (untested, I don't have Scala at hand):
def countSteps(n: Int): Int = {
  if (n < 0) return 0
  countStepsAux(n, 0, 1, 0, 0)
}
def countStepsAux(n: Int, step: Int, n1: Int, n2: Int, n3: Int): Int = {
  if (n == step) return n1
  countStepsAux(n, step + 1, n1 + n2 + n3, n1, n2)
}
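The accumulator version works because it keeps the last three results of the recurrence: at each step, n1, n2 and n3 hold the values for step, step - 1 and step - 2, starting from 1, 0, 0 at step 0. Below is a minimal sketch that checks this against the original definition; the names AccumulatorCheck, countStepsNaive, countStepsAcc and loop are hypothetical, and the helper is nested so that @tailrec can be applied to it.

```scala
import scala.annotation.tailrec

object AccumulatorCheck {
  // The original, non-tail-recursive definition from the question.
  def countStepsNaive(n: Int): Int =
    if (n < 0) 0
    else if (n == 0) 1
    else countStepsNaive(n - 1) + countStepsNaive(n - 2) + countStepsNaive(n - 3)

  def countStepsAcc(n: Int): Int = {
    // Invariant: n1, n2, n3 are the results for step, step - 1, step - 2.
    @tailrec
    def loop(step: Int, n1: Int, n2: Int, n3: Int): Int =
      if (step == n) n1
      else loop(step + 1, n1 + n2 + n3, n1, n2)

    if (n < 0) 0 else loop(0, 1, 0, 0)
  }

  def main(args: Array[String]): Unit = {
    println((0 to 6).map(countStepsAcc).mkString(", ")) // prints 1, 1, 2, 4, 7, 13, 24
    (0 to 15).foreach(n => assert(countStepsNaive(n) == countStepsAcc(n)))
  }
}
```

Unlike a manually managed stack, this runs in a number of steps linear in n and uses constant extra space. For large n the Int result will overflow, so a wider type such as BigInt may be needed.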
Upvotes: 6