joesan

Reputation: 15435

Scala Tail Recursion From a Flatmap

I have a recursive function defined as follows:

def getElems[A](a: A)(f: A => List[A]): List[A] = {
    f(a)
}

def parse[A](depth: Int, elems: List[A], f: A => List[A]): List[A] = {
  elems.flatMap(elem => {
    if (depth > 0) {
      parse(depth - 1, getElems(elem)(f), f)
    } else elems
  })
}
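A minimal run of the function above makes its behaviour concrete. It uses a hypothetical `children` function (not from the post) that appends "a" and "b" to a string. Note the `else elems` branch: once the depth reaches 0, each element of the current list is replaced by the whole list, so every group of leaves appears once per element it contains.

```scala
def getElems[A](a: A)(f: A => List[A]): List[A] = f(a)

def parse[A](depth: Int, elems: List[A], f: A => List[A]): List[A] =
  elems.flatMap { elem =>
    if (depth > 0) parse(depth - 1, getElems(elem)(f), f)
    else elems // returns the whole current list, not just elem
  }

// Hypothetical child function, for illustration only
val children: String => List[String] = s => List(s + "a", s + "b")

val out = parse(1, List("1", "2"), children)
// out == List("1a", "1b", "1a", "1b", "2a", "2b", "2a", "2b")
```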

As you can see, for every elem in elems I run a function that gives me back another List, and I repeat this until the depth reaches 0. For example, I start with some elems and a depth like this:

parse(depth = 2, elems = List("1", "2"), someFnThatGivesBackAListOfString)

For each element in elems, the code above checks the depth; if it is greater than 0, it runs the function on that elem and repeats the same process on the result until the depth reaches 0. This works as expected, but it is not stack safe, so I'm thinking of writing a tail-recursive implementation instead. My understanding is that tail recursion is about reduction, but that is not the case here. So how do I make this stack safe, or how can I express this logic tail-recursively?
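One way to get stack safety without restructuring the algorithm is trampolining, a technique the post itself does not use. The standard library's `scala.util.control.TailCalls` wraps each recursive step in a `TailRec` value that is run in a loop on the heap. Note that the original's call stack grows with `depth`, not with the number of elements, because `flatMap` iterates over each list; overflow is therefore only a risk for large depths. The sketch below assumes the original semantics (including `else elems`) should be kept, calls `f` directly instead of `getElems` (which only applies `f`), and uses hypothetical names `parseT` and `parseSafe`.

```scala
import scala.util.control.TailCalls._

// Hypothetical trampolined version: each recursive step is wrapped in
// TailRec, so the JVM stack stays shallow regardless of depth.
def parseT[A](depth: Int, elems: List[A], f: A => List[A]): TailRec[List[A]] = {
  // Walk elems one at a time, appending each element's result to acc.
  def go(rest: List[A], acc: Vector[A]): TailRec[Vector[A]] = rest match {
    case Nil => done(acc)
    case elem :: tl =>
      val sub: TailRec[List[A]] =
        if (depth > 0) tailcall(parseT(depth - 1, f(elem), f))
        else done(elems) // same as the original `else elems`
      sub.flatMap(xs => tailcall(go(tl, acc ++ xs)))
  }
  go(elems, Vector.empty).map(_.toList)
}

// Runs the trampoline to completion.
def parseSafe[A](depth: Int, elems: List[A], f: A => List[A]): List[A] =
  parseT(depth, elems, f).result

// Hypothetical deep chain: each Int has exactly one child, n + 1
val deep = parseSafe(100000, List(0), (n: Int) => List(n + 1))
// deep == List(100000)
```

The trade-off is an extra heap allocation per step in exchange for a JVM stack that does not grow with `depth`.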

I started with something like this, but this is not quite right:

def firstAttempt[A](ls: List[A], depthOrig: Int)(f: (A => List[A])): List[A] = {
    @annotation.tailrec
    def helper(acc: List[A], ls: List[A], depth: Int): List[A] =
      ls match {
        case Nil => acc
        case sublist @ (head :: tail) =>
          // Expand head with f while depth remains, otherwise keep it unchanged
          if (depth > 0)
            helper(acc ::: f(head), tail, depth - 1)
          else
            helper(acc.appended(head), tail, depthOrig)
      }
    helper(Nil, ls, depthOrig)
  }
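Running this attempt with the same hypothetical `children` function (not from the post) shows where it diverges: a single `depth` counter is decremented once per list element rather than once per level and reset after each unexpanded element, and the elements returned by `f` go straight into `acc` without ever being expanded further. The sketch below drops the unused `sublist @` binding but otherwise keeps the attempt as written; `appended` requires Scala 2.13 or later.

```scala
def firstAttempt[A](ls: List[A], depthOrig: Int)(f: A => List[A]): List[A] = {
  @annotation.tailrec
  def helper(acc: List[A], ls: List[A], depth: Int): List[A] =
    ls match {
      case Nil => acc
      case head :: tail =>
        if (depth > 0) helper(acc ::: f(head), tail, depth - 1)
        else helper(acc.appended(head), tail, depthOrig)
    }
  helper(Nil, ls, depthOrig)
}

// Hypothetical child function, for illustration only
val children: String => List[String] = s => List(s + "a", s + "b")

// "1" is expanded once, its children are never revisited, and "2" is kept as-is
val attempt = firstAttempt(List("1", "2"), 1)(children)
// attempt == List("1a", "1b", "2")
```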

Upvotes: 1

Views: 173

Answers (1)

jwvh

Reputation: 51271

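I got this to work by attaching the current depth to each element. Because every pending element carries its own remaining depth, all outstanding work fits in one list, and a single tail-recursive loop can consume it from the front.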

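def parse[A](depth: Int, elems: List[A], f: A => List[A]): List[A] = {
  // todo: pending elements paired with their remaining depth; acc: finished output
  @annotation.tailrec
  def loop(todo: List[(A, Int)], acc: List[A]): List[A] = todo match {
    case Nil => acc
    case (_, dpth) :: _ if dpth < 1 =>
      // Take the whole run of depth-0 siblings; like `else elems` in the
      // question, the run is emitted once per element it contains
      val (zs, td) = todo.span(_._2 < 1)
      loop(td, acc ++ zs.flatMap(_ => zs.map(_._1)))
    case (elm, dpth) :: tl =>
      // Expand elm and push its children, one level shallower, onto the front
      loop(f(elm).map(_ -> (dpth - 1)) ++ tl, acc)
  }
  loop(elems.map(_ -> depth), Nil)
}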
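A quick way to check that the worklist version agrees with the question's recursive `parse` is to run both on the same input. In the sketch below the two are renamed `parseRec` and `parseLoop` (hypothetical names), `children` is a hypothetical child function, and the loop accumulates into a `Vector` instead of a `List`, since `acc ++ ...` on a `List` copies the whole accumulator on every step. The output is otherwise unchanged.

```scala
// The question's recursive version (stack depth grows with `depth`)
def parseRec[A](depth: Int, elems: List[A], f: A => List[A]): List[A] =
  elems.flatMap(elem => if (depth > 0) parseRec(depth - 1, f(elem), f) else elems)

// The answer's worklist loop, accumulating into a Vector for cheaper appends
def parseLoop[A](depth: Int, elems: List[A], f: A => List[A]): List[A] = {
  @annotation.tailrec
  def loop(todo: List[(A, Int)], acc: Vector[A]): Vector[A] = todo match {
    case Nil => acc
    case (_, d) :: _ if d < 1 =>
      val (zs, td) = todo.span(_._2 < 1)
      loop(td, acc ++ zs.flatMap(_ => zs.map(_._1)))
    case (elm, d) :: tl =>
      loop(f(elm).map(_ -> (d - 1)) ++ tl, acc)
  }
  loop(elems.map(_ -> depth), Vector.empty).toList
}

// Hypothetical child function, for illustration only
val children: String => List[String] = s => List(s + "a", s + "b")

// Compare both versions for depths 0 through 3
val same = (0 to 3).forall(d =>
  parseRec(d, List("1", "2"), children) == parseLoop(d, List("1", "2"), children))
// same == true
```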

Upvotes: 2
