Dici

Reputation: 25950

Continuation-passing style in Scala

I have superficially read a couple of blog articles and the Wikipedia page about continuation-passing style. My high-level goal is to find a systematic technique for making any recursive function tail-recursive (or, if there are restrictions, to be aware of them). However, I have trouble articulating my thoughts, and I'm not sure whether my attempts make any sense.

As an example, here is a simple problem. Given a sorted list of unique characters and a length n, the goal is to output, in alphabetical order, all possible words of length n made out of these characters. For example, sol("op".toList, 3) should return ooo,oop,opo,opp,poo,pop,ppo,ppp.

My recursive solution is the following:

def sol(chars: List[Char], n: Int) = {
    def recSol(n: Int): List[List[Char]] = (chars, n) match {
        case (_  , 0) => List(Nil)
        case (Nil, _) => Nil
        case (_  , _) =>
            val tail = recSol(n - 1)
            chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _)
    }
    recSol(n).map(_.mkString).mkString(",")
}

I tried to rewrite this by adding a function parameter, but I could not produce anything I was convinced was tail-recursive. I'd rather not include my attempts in the question, as I'm embarrassed by how naive they are, so please excuse me for this.

Therefore, the question is basically: how would the function above be written in CPS?

Upvotes: 2

Views: 2086

Answers (2)

Nathan Davis

Reputation: 5766

The first order of business in performing the CPS transform is deciding on a representation for continuations. We can think of a continuation as a suspended computation with a "hole". Once the hole is filled in with a value, the rest of the computation can be carried out. So functions are a natural choice for representing continuations, at least for toy examples:

type Cont[Hole,Result] = Hole => Result

Here Hole represents the type of the hole that needs to be filled in, and Result represents the type of value the computation ultimately computes.
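As a tiny illustration (the names describe and filled are hypothetical), a continuation with an Int hole and a String result is just an ordinary function value, and filling the hole resumes the suspended computation:

```scala
// Same alias as above: a continuation maps the value that fills the "hole"
// to the final result of the whole computation.
type Cont[Hole, Result] = Hole => Result

// Hypothetical continuation: it waits for an Int and finishes by building a String.
val describe: Cont[Int, String] = n => s"the answer is $n"

// Filling the hole with 42 runs the rest of the computation.
val filled: String = describe(42)
```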

Now that we have a way to represent continuations, we can worry about the CPS transform itself. Basically, this involves the following steps:

  • The transformation is applied recursively to an expression, stopping at "trivial" expressions and function calls. In this context, "trivial" includes functions provided by Scala itself (such as map or mkString), since they are not CPS-transformed and thus do not have a continuation parameter.
  • We need to add a parameter of type Cont[Return,Result] to each function, where Return is the return type of the untransformed function and Result is the type of the ultimate result of the computation as a whole. This new parameter represents the current continuation. The return type for the transformed function is also changed to Result.
  • Every function call needs to be transformed to accommodate the new continuation parameter. Everything after the call needs to be put into a continuation function, which is then added to the parameter list.
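To see the three steps on a function that makes two non-trivial calls in sequence, here is a minimal sketch (h and hCps are hypothetical names; f and fCps are the same as in the example that follows). The second call moves into the first call's continuation, and the trivial + stays as it is:

```scala
type Cont[Hole, Result] = Hole => Result

// Direct style: two non-trivial calls, then a trivial addition.
def f(x: Int): Int = x + 1
def h(x: Int): Int = f(x) + f(2 * x)

// CPS: each function takes a continuation, and everything that happens
// after a call is moved into that call's continuation.
def fCps[Result](x: Int)(k: Cont[Int, Result]): Result = k(x + 1)

def hCps[Result](x: Int)(k: Cont[Int, Result]): Result =
  fCps(x) { a =>        // a fills the hole left by f(x)
    fCps(2 * x) { b =>  // b fills the hole left by f(2 * x)
      k(a + b)          // the trivial addition; its result goes to k
    }
  }
```

Calling hCps(3)(x => x) gives the same value as h(3).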

For example, a function:

def f(x : Int) : Int = x + 1

becomes:

def fCps[Result](x : Int)(k : Cont[Int,Result]) : Result = k(x + 1)

and

def g(x : Int) : Int = 2 * f(x)

becomes:

def gCps[Result](x : Int)(k : Cont[Int,Result]) : Result = {
  fCps(x)(y => k(2 * y))
}

Now gCps(5), with its continuation not yet supplied, represents a partial computation. We can extract the value from this partial computation and use it by supplying a continuation function. For example, we can use the identity function to extract the value unchanged:

gCps(5)(x => x)
// 12

Or, we can print it by using println instead:

gCps(5)(println)
// prints 12
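The same transform applied to a directly recursive function shows why CPS helps with tail recursion. Here is a sketch using factorial, chosen only for illustration (fact, factCps and loop are hypothetical names). Because the pending multiplication moves into the continuation, the recursive call becomes the last thing loop does, so @tailrec accepts it:

```scala
import scala.annotation.tailrec

type Cont[Hole, Result] = Hole => Result

// Direct style: the multiplication happens after the recursive call returns,
// so that call is not in tail position.
def fact(n: Int): BigInt = if (n == 0) BigInt(1) else BigInt(n) * fact(n - 1)

// CPS: "multiply by n" is deferred into the continuation passed down.
def factCps[Result](n: Int)(k: Cont[BigInt, Result]): Result = {
  @tailrec
  def loop(n: Int, k: Cont[BigInt, Result]): Result =
    if (n == 0) k(BigInt(1))
    else loop(n - 1, r => k(r * BigInt(n)))
  loop(n, k)
}
```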

Applying this to your code, we obtain:

def solCps[Result](chars : List[Char], n : Int)(k : Cont[String, Result]) : Result = {
  // recSol shares solCps's Result type: the type of the final answer
  @scala.annotation.tailrec
  def recSol(n : Int)(k : Cont[List[List[Char]], Result]) : Result = (chars, n) match {
    case (_  , 0) => k(List(Nil))
    case (Nil, _) => k(Nil)
    case (_  , _) =>
      // the work left after the recursive call moves into its continuation
      recSol(n - 1)(tail =>
                      k(chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _)))
  }

  recSol(n)(result =>
              k(result.map(_.mkString).mkString(",")))
}

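As you can see, although recSol is now tail-recursive, it comes at the cost of building a more complex continuation at each iteration. So, at best, we have traded space on the JVM's control stack for space on the heap: the CPS transform does not magically reduce the space complexity of an algorithm. Moreover, because the JVM does not eliminate tail calls, invoking the chain of nested continuations once the base case is reached still uses stack frames proportional to n, unless the continuations are trampolined.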
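Calling the chain of continuations as nested JVM calls is what keeps the stack growing with n. A sketch of one common remedy, trampolining, using the standard library's scala.util.control.TailCalls (done, tailcall and TailRec.result). The name solTrampolined is hypothetical, and flatMap stands in for the equivalent fold(Nil)(_ ::: _) concatenation. Each continuation returns a TailRec instead of calling the next one directly, so the trampoline runs them one at a time:

```scala
import scala.util.control.TailCalls._

def solTrampolined(chars: List[Char], n: Int): String = {
  // Continuations return a TailRec: invoking one yields a suspended step
  // that the trampoline (.result) runs in a loop, not a nested JVM call.
  def recSol(n: Int)(k: List[List[Char]] => TailRec[List[List[Char]]]): TailRec[List[List[Char]]] =
    (chars, n) match {
      case (_, 0)   => tailcall(k(List(Nil)))
      case (Nil, _) => tailcall(k(Nil))
      case (_, _) =>
        tailcall(recSol(n - 1) { tail =>
          // same step as the original: prepend each char to every shorter word
          tailcall(k(chars.flatMap(ch => tail.map(ch :: _))))
        })
    }
  recSol(n)(words => done(words)).result.map(_.mkString).mkString(",")
}
```

With this shape, a deep call such as solTrampolined(List('a'), 100000) should finish without a StackOverflowError, at the price of allocating one suspended step per continuation on the heap.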

Also, recSol is only tail-recursive because the recursive call to recSol happens to be the first (non-trivial) expression it evaluates. In general, though, recursive calls would take place inside a continuation. When there is only one recursive call, we can work around that by CPS-transforming only the calls to the recursive function. Even so, in general, we would still just be trading stack space for heap space.
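To illustrate the case with more than one recursive call, here is a sketch over a hypothetical binary tree (Tree, Leaf, Node, treeSize and treeSizeCps are illustrative names). The first recursive call is in tail position, but the second one can only run once the first result is known, so it ends up inside a continuation and the method as a whole is not tail-recursive:

```scala
type Cont[Hole, Result] = Hole => Result

// Hypothetical binary tree, used only for illustration.
sealed trait Tree
case object Leaf extends Tree
final case class Node(left: Tree, right: Tree) extends Tree

// Direct style: counts the Node values with two recursive calls.
def treeSize(t: Tree): Int = t match {
  case Leaf       => 0
  case Node(l, r) => 1 + treeSize(l) + treeSize(r)
}

// CPS: the call on l is in tail position, but the call on r
// happens inside l's continuation, i.e. not in tail position.
def treeSizeCps[Result](t: Tree)(k: Cont[Int, Result]): Result = t match {
  case Leaf       => k(0)
  case Node(l, r) => treeSizeCps(l)(sl => treeSizeCps(r)(sr => k(1 + sl + sr)))
}
```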

Upvotes: 3

Régis Jean-Gilles

Reputation: 32719

Try this:

import scala.annotation.tailrec
def sol(chars: List[Char], n: Int) = {
  @tailrec
  def recSol(n: Int)(cont: (List[List[Char]]) => List[List[Char]]): List[List[Char]] = (chars, n) match {
    case (_  , 0) => cont(List(Nil))
    case (Nil, _) => cont(Nil)
    case (_  , _) =>
      recSol(n-1){ tail =>
        cont(chars.map(ch => tail.map(ch :: _)).fold(Nil)(_ ::: _))
      }
  }
  recSol(n)(identity).map(_.mkString).mkString(",")
}
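For comparison, a sketch of a different, problem-specific technique: accumulator passing, not CPS. Every continuation built by recSol applies the same step (prepend each character to every word), so the words can instead be built bottom-up from length 0 to n in a @tailrec loop that needs neither nested continuations nor a trampoline. The names solAcc and loop are hypothetical:

```scala
import scala.annotation.tailrec

def solAcc(chars: List[Char], n: Int): String = {
  // acc holds all words of length i, in alphabetical order
  @tailrec
  def loop(i: Int, acc: List[List[Char]]): List[List[Char]] =
    if (i >= n) acc
    else loop(i + 1, chars.flatMap(ch => acc.map(ch :: _)))
  loop(0, List(Nil)).map(_.mkString).mkString(",")
}
```

Unlike the CPS transform, this rewrite depends on the shape of this particular recursion, so it is not a systematic technique.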

Upvotes: 3
