I am trying to convert some functions into CPS in Scala, using examples I found online in other languages.
We only had one lecture on CPS and everything should be really basic, but I still don't quite understand how to do the transformations properly.
These are the examples I found online (in Haskell), which I'm trying to write in Scala:
pow2' :: Float -> (Float -> a) -> a
pow2' a cont = cont (a ** 2)
add' :: Float -> Float -> (Float -> a) -> a
add' a b cont = cont (a + b)
sqrt' :: Float -> ((Float -> a) -> a)
sqrt' a = \cont -> cont (sqrt a)
pyth' :: Float -> Float -> (Float -> a) -> a
pyth' a b cont = pow2' a (\a2 -> pow2' b (\b2 -> add' a2 b2 (\anb -> sqrt' anb cont)))
I started with pow2', which currently looks like this:
def pow2_k(a:Float, k:(Float => Float)) : Float =
  (a*a)
def pow2_cont(n: Float) = pow2_k(n, (x: Float) => x)
At first I had k(a*a) instead of simply (a*a), which led to weird results.
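For comparison, here is a small sketch of what the (a*a) version does with a continuation that is not the identity. Both function names are hypothetical. A function that accepts k but never calls it silently drops the continuation, so whatever k was supposed to do never happens:

```scala
object IgnoredContinuation {
  // Like the question's pow2_k: k is accepted but never called.
  def pow2Ignoring(a: Float, k: Float => Float): Float = a * a

  // Calling k with the result is what makes the function continuation-passing.
  def pow2Calling(a: Float, k: Float => Float): Float = k(a * a)

  def main(args: Array[String]): Unit = {
    val addOne = (x: Float) => x + 1
    println(pow2Ignoring(3.0f, addOne)) // 9.0: addOne never runs
    println(pow2Calling(3.0f, addOne))  // 10.0: addOne receives 9.0
  }
}
```

With the identity continuation the two versions happen to agree, which can hide the problem.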
Next I tried add', which at the moment looks like this:
def add_k(a:Float, b:Float, k:(Float, Float => Float)) : Float =
  (a+b)
def add_k_cont(n:Float,m:Float) = add_k(n,m (x: Float, y:Float => (n+m)))
This is obviously wrong: I have trouble writing the correct continuation. Does anyone know a good site, paper, video, etc. that explains the CPS transformation and continuations? Most of the ones I found are either too short and lack examples, or are far too complex. I feel like this shouldn't be that complicated for the simple functions I'm trying to convert...
Thank you.
Answer:
def pow2_k[A](f: Float, k: Float => A): A = k(f * f)
def pow2_cont(n: Float): Float = pow2_k(n, identity)
scala> pow2_cont(2.0f)
res0: Float = 4.0
scala> pow2_k(4.0f, println)
16.0
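The same pattern should carry over to the rest of the question's Haskell examples. The sketch below is one possible transliteration; the names add_k, sqrt_k and pyth_k are my own choices, not from any library. Note that add_k's continuation receives a single value, the sum, so its type is Float => A. (In the question's signature, (Float, Float => Float) is actually a tuple type, not a two-argument function type.) sqrt_k mirrors the Haskell sqrt' by returning a function that waits for the continuation.

```scala
object CpsPyth {
  // pow2' a cont = cont (a ** 2)
  def pow2_k[A](a: Float, cont: Float => A): A = cont(a * a)

  // add' a b cont = cont (a + b): the continuation gets one Float, the sum
  def add_k[A](a: Float, b: Float, cont: Float => A): A = cont(a + b)

  // sqrt' a = \cont -> cont (sqrt a): returns a function awaiting the continuation
  def sqrt_k[A](a: Float): (Float => A) => A =
    cont => cont(math.sqrt(a.toDouble).toFloat)

  // Each step hands its result to a lambda that performs the next step;
  // only the innermost call uses the caller's continuation.
  def pyth_k[A](a: Float, b: Float, cont: Float => A): A =
    pow2_k[A](a, a2 =>
      pow2_k[A](b, b2 =>
        add_k[A](a2, b2, anb =>
          sqrt_k[A](anb)(cont))))

  def main(args: Array[String]): Unit = {
    println(pyth_k(3.0f, 4.0f, (r: Float) => r)) // 5.0
    pyth_k(3.0f, 4.0f, (r: Float) => println(s"hypotenuse = $r"))
  }
}
```

The explicit [A] type arguments are there so the compiler can infer the lambda parameter types without extra annotations.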
Probably the simplest way to think about the CPS transform is along these lines:
Take a function that's not in CPS:
def square(f: Float): Float = f * f
Make the function generic in a type A; add a second parameter, a function from the result type of your old function to A; and change your function's result type to A:
def cpsSquare[A](f: Float, ret: Float => A): A = ???
Everywhere your function "returned" a value (note, of course, that return in Scala has odd semantics, so using return is a real code smell), pass the value that would have been returned to that function instead (calling the function ret hints at what's going on):
def cpsSquare[A](f: Float, ret: Float => A): A = ret(f * f)
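The "everywhere" in the last step matters once a function has more than one return point. As a sketch, with abs and cpsAbs as hypothetical examples, each branch that used to produce a value now passes it to ret:

```scala
object CpsAbs {
  // Direct style: two places where a value is "returned".
  def abs(f: Float): Float = if (f < 0) -f else f

  // CPS: each of those places calls ret with the value instead.
  def cpsAbs[A](f: Float, ret: Float => A): A =
    if (f < 0) ret(-f) else ret(f)

  def main(args: Array[String]): Unit = {
    println(cpsAbs(-2.5f, (x: Float) => x)) // 2.5
    println(cpsAbs(2.5f, (x: Float) => x))  // 2.5
  }
}
```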
Basically, in CPS, instead of returning a value, you call a passed function (the continuation) with the value that you'd return.
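Because the result type A is chosen by the caller, the continuation decides what the whole call produces. A minimal sketch, reusing cpsSquare from the steps above (the object name is hypothetical):

```scala
object ContinuationChoices {
  // cpsSquare from the answer, repeated so this block is self-contained.
  def cpsSquare[A](f: Float, ret: Float => A): A = ret(f * f)

  def main(args: Array[String]): Unit = {
    // Identity continuation: behaves like the direct-style square.
    val plain: Float = cpsSquare(3.0f, (x: Float) => x)
    // The continuation picks the final result type, here a String...
    val text: String = cpsSquare(3.0f, (x: Float) => s"squared: $x")
    // ...or a Boolean.
    val big: Boolean = cpsSquare(3.0f, (x: Float) => x > 5)
    println(plain) // 9.0
    println(text)  // squared: 9.0
    println(big)   // true
  }
}
```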