TolleWurst

Reputation: 51

CPS Transformation [Scala]

I am trying to convert some functions into CPS in Scala, using examples I found online in other languages.

We only had one lecture on CPS, and everything should be really basic, but I still don't quite understand how to do the transformations properly.

These are the examples I found online (in Haskell) that I'm trying to reproduce in Scala.

pow2' :: Float -> (Float -> a) -> a
pow2' a cont = cont (a ** 2)

add' :: Float -> Float -> (Float -> a) -> a
add' a b cont = cont (a + b)

sqrt' :: Float -> ((Float -> a) -> a)
sqrt' a = \cont -> cont (sqrt a)

pyth' :: Float -> Float -> (Float -> a) -> a
pyth' a b cont = pow2' a (\a2 -> pow2' b (\b2 -> add' a2 b2 (\anb -> sqrt' anb cont)))

I started with pow2', which looks like this:

def pow2_k(a:Float, k:(Float => Float)) : Float =
    (a*a)

def pow2_cont(n: Float) = pow2_k(n, (x: Float) => x)

At first I had k(a*a) instead of simply (a*a), which led to weird results.

Next I tried add', which at the moment looks like this:

def add_k(a:Float, b:Float, k:(Float, Float => Float)) : Float =
    (a+b)

def add_k_cont(n:Float,m:Float) = add_k(n,m (x: Float, y:Float => (n+m)))

This is obviously wrong. I have trouble writing the correct continuation. Does anyone know a good site, paper, video, etc. that explains CPS transformation and continuations? Most of the ones I found are either too short and lack examples, or are far too complex. I feel this should not be that complicated for the simple functions I am trying to convert...

Thank you.

Upvotes: 1

Views: 974

Answers (1)

Levi Ramsey

Reputation: 20551

def pow2_k[A](f: Float, k: Float => A): A = k(f * f)

def pow2_cont(n: Float): Float = pow2_k(n, identity)

scala> pow2_cont(2.0f)
res0: Float = 4.0

scala> pow2_k(4.0f, println)
16.0

Probably the simplest way to imagine the CPS transform is along the lines of:

  • Take a function that's not in CPS

    def square(f: Float): Float = f * f
    
  • Make the function generic in a type A; add a second parameter, a function from your old function's result type to A; and change your function's result type to A

    def cpsSquare[A](f: Float, ret: Float => A): A = ???
    
  • Everywhere your function "returned" a value, pass that value to the new function parameter instead (that I called this parameter ret hints at what's going on). Note that the return keyword in Scala has odd semantics, so an explicit return is a really bad code smell; "returned" here means the value the function body evaluates to

    def cpsSquare[A](f: Float, ret: Float => A): A = ret(f * f)
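
The "everywhere" in the last step matters once a function has more than one exit point. As a minimal sketch, assuming a hypothetical abs function that is not part of the question, each branch passes its result to ret:

```scala
// Hypothetical direct-style function with two "return" points.
def abs(f: Float): Float =
  if (f < 0) -f else f

// CPS version: every branch hands its result to `ret` instead of returning it.
def cpsAbs[A](f: Float, ret: Float => A): A =
  if (f < 0) ret(-f) else ret(f)
```

Calling cpsAbs(-3.0f, (x: Float) => x) gives back the plain value, just as pow2_cont does with identity.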
    

Basically, in CPS, instead of returning a value, you call a passed function (the continuation) with the value that you'd return.
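
Applying the same recipe to the rest of the Haskell examples from the question might look like the sketch below. The names add_k, sqrt_k, pyth_k and pyth are only illustrative, chosen to match the question's naming. Note that the continuation for add takes a single Float (the sum), not the two original arguments:

```scala
def pow2_k[A](a: Float, k: Float => A): A = k(a * a)

// The continuation receives the result of the addition, so it is Float => A.
def add_k[A](a: Float, b: Float, k: Float => A): A = k(a + b)

// math.sqrt works on Double, so the result is converted back to Float.
def sqrt_k[A](a: Float, k: Float => A): A = k(math.sqrt(a).toFloat)

// Each step's continuation names the intermediate result and starts the next step;
// the final step passes the caller's continuation k straight through.
def pyth_k[A](a: Float, b: Float, k: Float => A): A =
  pow2_k(a, (a2: Float) =>
    pow2_k(b, (b2: Float) =>
      add_k(a2, b2, (sum: Float) =>
        sqrt_k(sum, k))))

// Leaving CPS: pass the identity function as the final continuation.
def pyth(a: Float, b: Float): Float = pyth_k(a, b, (x: Float) => x)
```

The nesting mirrors the Haskell pyth' line by line: every lambda is "the rest of the computation" after the call it is passed to.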

Upvotes: 2
