user6031489

Scala tail recursion

I have a function in Scala, and I wonder whether it's possible to make it tail-recursive.

def get_f(f: Int => Int, x: Int, y: Int): Int = x match {
  case 0 => y
  case _ => f(get_f(f, x - 1, y))
}
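The call isn't in tail position because of how it unfolds: get_f(f, 3, y) becomes f(f(f(y))), so the outer f still has work to do after each recursive call returns, and every level keeps its own stack frame. A quick check of that unfolding, using _ * 2 as an arbitrary example function (double is just an illustrative name):

```scala
// The question's original (non-tail-recursive) definition
def get_f(f: Int => Int, x: Int, y: Int): Int = x match {
  case 0 => y
  case _ => f(get_f(f, x - 1, y))
}

// Illustrative example function, not from the question
val double: Int => Int = _ * 2

// get_f(double, 3, 1) unfolds to double(double(double(1)))
val unfolded = double(double(double(1)))
println(get_f(double, 3, 1)) // 8
println(unfolded)            // 8
```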

Upvotes: 3

Views: 1129

Answers (5)

adamwy

Reputation: 1239

I'll add that you can achieve the same result by using foldLeft on a Range, like this:

def get_f(f: Int => Int, x: Int, y: Int) =
  (0 until x).foldLeft(y)((acc, _) => f(acc))
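A small sketch checking that the foldLeft version agrees with the question's recursive definition on small inputs, and that it copes with a large x because foldLeft iterates rather than recursing (get_f_rec is a hypothetical name for the original, renamed so both can coexist):

```scala
// foldLeft-based version from the answer above
def get_f(f: Int => Int, x: Int, y: Int) =
  (0 until x).foldLeft(y)((acc, _) => f(acc))

// Hypothetical name for the question's original recursive definition
def get_f_rec(f: Int => Int, x: Int, y: Int): Int = x match {
  case 0 => y
  case _ => f(get_f_rec(f, x - 1, y))
}

// Both agree on small inputs...
val agree = (0 to 10).forall(n => get_f(_ * 3 + 1, n, 2) == get_f_rec(_ * 3 + 1, n, 2))

// ...and the fold runs in constant stack space, so a large x is fine
val big = get_f(_ + 1, 1000000, 0)
```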

Upvotes: 2

Emiliano Martinez

Reputation: 4133

In line with the previous answers:

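import scala.annotation.tailrec

def get_f2(f: Int => Int, x: Int, y: Int): Int = {
  // Accumulator-style helper: y holds the value after the applications
  // done so far, x counts how many applications of f remain
  @tailrec
  def tail(y: Int, x: Int)(f: Int => Int): Int = x match {
    case 0 => y
    case _ => tail(f(y), x - 1)(f)
  }

  tail(y, x)(f)
}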

Upvotes: 0

dk14

Reputation: 22374

Let's start by reducing the number of parameters in your non-tail-recursive version to make it clear what it actually does:

def get_f(f: Int => Int, x: Int, y: Int) = {
  def get_f_impl(x: Int): Int = x match {
    case 0 => y
    case _ => f(get_f_impl(x - 1))
  }
  get_f_impl(x)
}

The idea is that you actually apply f to the initial value y, x times. Once that is clear, you can make it tail-recursive by carrying the intermediate value in an accumulator:

import scala.annotation.tailrec

def get_f(f: Int => Int, x: Int, y: Int) = {
  // acc carries y after the applications of f done so far
  @tailrec def get_f_impl(acc: Int, x: Int): Int =
    if (x == 0) acc else get_f_impl(f(acc), x - 1)
  get_f_impl(y, x)
}
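The accumulator pattern isn't specific to Int. A minimal generic sketch, assuming you want to apply any f: A => A to a start value n times (applyN is a hypothetical name, not from the answer, and treating n <= 0 as "zero applications" is an assumption of this sketch):

```scala
import scala.annotation.tailrec

// Hypothetical generic helper: applies f to start, n times.
// The recursive call is the last operation, so @tailrec compiles it to a loop.
// Assumption: n <= 0 means zero applications (returns start unchanged).
@tailrec
def applyN[A](f: A => A, n: Int, start: A): A =
  if (n <= 0) start
  else applyN(f, n - 1, f(start))

val four  = applyN[Int](_ + 1, 4, 0)
val bangs = applyN[String](_ + "!", 3, "hi")
val many  = applyN[Long](_ + 2L, 1000000, 0L)
```

With the type parameter given explicitly, the placeholder lambdas are typed without further annotations.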

REPL investigation:

Your original implementation:

scala> get_f(_ + 1, 4, 0)
res6: Int = 4

Your implementation (with the reduced parameter list):

scala> get_f(_ + 1, 4, 0)
res0: Int = 4

Tailrec implementation:

scala> get_f(_ + 1, 4, 0)
res3: Int = 4

P.S. For more complex cases, trampolines might fit: https://espinhogr.github.io/scala/2015/07/12/trampolines-in-scala.html

P.S.2 You can also try:

Upvotes: 2

wheaties

Reputation: 35980

It is possible, but the way you've constructed it means you'll have to use a trampolined style to make it work:

import scala.util.control.TailCalls._

def get_f(f: Int => Int, x: Int, y: Int): TailRec[Int] = x match {
  case 0 => done(y)
  case _ => tailcall(get_f(f, x - 1, y)).map(f)
}

val answer = get_f(_+1, 0, 24).result

You can read about TailRec here or for more advanced study, this paper.
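The example call above uses x = 0, which returns immediately and never exercises the recursion. A sketch with a large x, assuming Scala 2.11 or later (where TailRec has map), shows the trampoline evaluating many nested steps without growing the stack:

```scala
import scala.util.control.TailCalls._

// The answer's trampolined definition
def get_f(f: Int => Int, x: Int, y: Int): TailRec[Int] = x match {
  case 0 => done(y)
  case _ => tailcall(get_f(f, x - 1, y)).map(f)
}

// tailcall defers the recursive call, so building the TailRec value is cheap;
// .result then runs the deferred steps in a loop, keeping continuations on the heap
val small = get_f(_ + 1, 4, 0).result
val large = get_f(_ + 1, 100000, 0).result
```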

Upvotes: 4

Murat Mustafin

Reputation: 1313

I see that this function applies f to the result of the recursive call, x times in total. That is the same as applying f to y x times, so the intermediate value can simply be passed along as the new y. I'd also suggest using if/else instead of pattern matching.

import scala.annotation.tailrec

@tailrec
def get_f(f: Int => Int, x: Int, y: Int): Int =
  if (x == 0) y
  else get_f(f, x - 1, f(y))

Add the @tailrec annotation to ensure that it is tail-recursive: the compiler then reports an error if the recursive call cannot be turned into a loop.
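One caveat that applies to every version in this thread: none of them handles a negative x. The foldLeft version returns y unchanged (0 until x is empty), the question's original never reaches case 0 and overflows the stack, and the tail-recursive versions count down past Int.MinValue, wrap around to Int.MaxValue, and only reach 0 after roughly 2^32 calls to f. A sketch that rejects negative counts up front (the require guard is an addition of this sketch, not part of the answer):

```scala
import scala.annotation.tailrec

// Entry point validates x once; the inner tail-recursive loop assumes x >= 0
def get_f(f: Int => Int, x: Int, y: Int): Int = {
  require(x >= 0, s"x must be non-negative, got $x")

  @tailrec
  def loop(n: Int, acc: Int): Int =
    if (n == 0) acc
    else loop(n - 1, f(acc))

  loop(x, y)
}

val ok = get_f(_ + 1, 4, 0)

// require throws IllegalArgumentException when the condition is false
val rejected =
  try { get_f(_ + 1, -1, 0); false }
  catch { case _: IllegalArgumentException => true }
```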

Upvotes: 7
