Reputation: 127
I am having a real tough time with tail recursion...
My current function filters out values less than 'n' from list 'l'
def filter(n: Int, l: List[Int]): List[Int] = l match {
  case Nil => Nil
  case hd :: tl => {
    if (hd < n) filter(n, tl)
    else hd :: filter(n, tl) // not a tail call: the cons happens after the recursive call returns
  }
}
When using large lists, this causes the stack to overflow.
Can someone please help me understand how to convert this to a tail recursive function?
Thanks for any input!
Upvotes: 2
Views: 3023
Reputation: 135217
@Brian's answer is nice, but its first version returns the filtered list in reverse order. That's generally not the intended behaviour.
@jwvh's recommendation is to pass the accumulator as a 3rd parameter to the function, but that leaks private API into the public API.
Either solution requires reversing the accumulator before returning the answer, which effectively means iterating through your input list twice. That seems wasteful, especially considering you're trying to implement this to handle large lists.
Consider this tail-recursive implementation which does not expose private API and does not require the accumulator to be reversed after filtering.
Disclaimer: this is the first Scala procedure I have ever written. Feedback on any implementation style or detail is welcome.
def filter(n: Int, xs: List[Int]): List[Int] = {
  @scala.annotation.tailrec
  def aux(k: List[Int] => List[Int], xs: List[Int]): List[Int] = xs match {
    case Nil => k(Nil)
    case x :: xs if (x < n) => aux(k, xs)
    case x :: xs => aux((rest: List[Int]) => k(x :: rest), xs)
  }
  aux(identity, xs)
}
filter(5, List(1,2,3,4,5,6,7,8,9,0))
// => List(5, 6, 7, 8, 9)
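One caveat worth testing on very large inputs: aux itself is tail-recursive, but the final k(Nil) call runs through one nested closure per kept element, and the JVM does not eliminate those calls, so the stack may still grow with the size of the result. Below is a minimal sketch of one way around that, assuming the standard library's scala.util.control.TailCalls trampoline is acceptable. The name filterTrampolined is hypothetical and not part of the code above.

```scala
import scala.util.control.TailCalls._

// Hypothetical variant of the continuation-passing filter, for illustration only.
// Each continuation returns a TailRec step instead of calling the previous
// continuation directly, so `.result` can unwind the chain in a loop.
def filterTrampolined(n: Int, xs: List[Int]): List[Int] = {
  @scala.annotation.tailrec
  def aux(k: List[Int] => TailRec[List[Int]], xs: List[Int]): TailRec[List[Int]] = xs match {
    case Nil => k(Nil)
    case x :: tail if x < n => aux(k, tail)
    case x :: tail => aux(rest => tailcall(k(x :: rest)), tail)
  }
  aux(rest => done(rest), xs).result
}

filterTrampolined(5, List(1,2,3,4,5,6,7,8,9,0))
// => List(5, 6, 7, 8, 9)
```

The trade-off is one extra allocation per kept element for the trampoline step, so this is a sketch of the idea rather than a recommendation over the accumulator approach.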
Upvotes: 3
Reputation: 20285
This is usually done with a helper function that accumulates the results. filterR has an additional parameter acc, to which we add the values that are greater than or equal to n.
def filter(n: Int, l: List[Int]): List[Int] = {
  @scala.annotation.tailrec
  def filterR(n: Int, l: List[Int], acc: List[Int]): List[Int] = l match {
    case Nil => acc // acc was built by prepending, so it is in reverse input order; use acc.reverse to restore it
    case hd :: tl if(hd < n) => filterR(n, tl, acc)
    case hd :: tl => filterR(n, tl, hd :: acc)
  }
  filterR(n, l, List[Int]())
}
With the suggestion from @jwvh:
@scala.annotation.tailrec
def filter(n: Int, l: List[Int], acc: List[Int] = List[Int]()): List[Int] = l match {
  case Nil => acc.reverse
  case hd :: tl if(hd < n) => filter(n, tl, acc)
  case hd :: tl => filter(n, tl, hd :: acc)
}
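If the final reverse is a concern, one possible alternative is to append to a mutable ListBuffer inside a private helper, which keeps the elements in input order and keeps the accumulator out of the public signature. This is only a sketch under that assumption. The names filterWithBuffer and loop are hypothetical.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical variant for illustration: tail-recursive traversal that
// appends kept values to a local buffer, so no reverse is needed.
def filterWithBuffer(n: Int, l: List[Int]): List[Int] = {
  val acc = ListBuffer.empty[Int]

  @scala.annotation.tailrec
  def loop(rest: List[Int]): Unit = rest match {
    case Nil => ()
    case hd :: tl =>
      if (hd >= n) acc += hd // keep values that are not less than n
      loop(tl)
  }

  loop(l)
  acc.toList
}

filterWithBuffer(5, List(1,2,3,4,5,6,7,8,9,0))
// => List(5, 6, 7, 8, 9)
```

Outside of an exercise, the built-in l.filter(_ >= n) gives the same result.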
Upvotes: 3