Michael Lafayette

Reputation: 3072

How to make your own for-comprehension-compliant Scala monad?

I want to implement my own for-comprehension compatible monads and functors in Scala.

Let's take two stupid monads as an example. One monad is a state monad that contains an "Int" that you can map or flatMap over.

val maybe = IntMonad(5)
maybe flatMap ( a => IntMonad(3 * a) map ( b => 2 * b ) )
// returns IntMonad(30)

Another monad does function composition like so...

val func = FunctionMonad( () => println("foo") )
val fooBar = func map ( _ => println("bar") )
fooBar()
// foo
// bar
// returns Unit

The example may have some mistakes, but you get the idea.

I want to be able to use these two different types of made up Monads inside a for-comprehension in Scala. Like this:

val myMonad = IntMonad(5)
for {
    a <- myMonad
    b <- IntMonad(a*2)
    c <- IntMonad(b*2)
} yield c
// returns IntMonad(20)

I am not a Scala master, but you get the idea.

Upvotes: 14

Views: 3667

Answers (1)

Michael Zajac

Reputation: 55569

For a type to be used within a for-comprehension, you really only need to define map and flatMap methods for it that return instances of the same type. Syntactically, the for-comprehension is transformed by the compiler into a series of flatMaps followed by a final map for the yield. As long as these methods are available with the appropriate signature, it will work.
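To make that translation concrete, here is a rough sketch using the standard library's Option instead of a custom type. The object name DesugarDemo and the values x, y and z are made up for illustration; the second expression is approximately what the compiler generates from the first.

```scala
object DesugarDemo {
  val x: Option[Int] = Some(1)
  val y: Option[Int] = Some(2)
  val z: Option[Int] = Some(3)

  // What you write: three generators and a yield.
  val sugared: Option[Int] =
    for {
      a <- x
      b <- y
      c <- z
    } yield a + b + c

  // Roughly what it becomes: every generator except the last
  // turns into flatMap, and the last one turns into map.
  val desugared: Option[Int] =
    x.flatMap(a => y.flatMap(b => z.map(c => a + b + c)))
}
```

Both values should come out the same, which is why any type with suitable map and flatMap methods can stand in for Option here.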

I'm not really sure what you're after with your examples, but here is a trivial example that behaves like Option[Int]:

sealed trait MaybeInt {
    def map(f: Int => Int): MaybeInt
    def flatMap(f: Int => MaybeInt): MaybeInt
}

case class SomeInt(i: Int) extends MaybeInt {
    def map(f: Int => Int): MaybeInt = SomeInt(f(i))
    def flatMap(f: Int => MaybeInt): MaybeInt = f(i)
}

case object NoInt extends MaybeInt {
    def map(f: Int => Int): MaybeInt = NoInt
    def flatMap(f: Int => MaybeInt): MaybeInt = NoInt
}

I have a common trait with two sub-types (I could have as many as I wanted, though). The common trait MaybeInt forces each sub-type to implement the map/flatMap interface.

scala> val maybe = SomeInt(1)
maybe: SomeInt = SomeInt(1)

scala> val no = NoInt
no: NoInt.type = NoInt

for {
  a <- maybe
  b <- no
} yield a + b

res10: MaybeInt = NoInt

for {
  a <- maybe
  b <- maybe
} yield a + b

res12: MaybeInt = SomeInt(2)

Additionally, you can add foreach and filter. If you want to also handle this (no yield):

for(a <- maybe) println(a)

You would add foreach. And if you want to use if guards:

for(a <- maybe if a > 2) yield a

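You would need withFilter, which is what the compiler translates guards into (older Scala versions fall back to filter if withFilter is missing).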
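As a sketch of how those two forms are translated, again using the standard Option so the snippet stands on its own (GuardDemo, seen and the other names are hypothetical, chosen only for this illustration):

```scala
import scala.collection.mutable.ListBuffer

object GuardDemo {
  val maybe: Option[Int] = Some(3)

  // A for without yield is translated into a foreach call.
  val seen: ListBuffer[Int] = ListBuffer.empty[Int]
  for (a <- maybe) seen += a        // what you write
  maybe.foreach(a => seen += a)     // roughly what it becomes

  // An if guard is translated into withFilter before the map.
  val guarded: Option[Int] =
    for (a <- maybe if a > 2) yield a
  val guardedDesugared: Option[Int] =
    maybe.withFilter(a => a > 2).map(a => a)
}
```

The point of withFilter over filter is that it can defer applying the predicate until the following map, flatMap or foreach runs, instead of building an intermediate value first.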

A full example:

sealed trait MaybeInt { self =>
    def map(f: Int => Int): MaybeInt
    def flatMap(f: Int => MaybeInt): MaybeInt
    def filter(f: Int => Boolean): MaybeInt
    def foreach[U](f: Int => U): Unit
    def withFilter(p: Int => Boolean): WithFilter = new WithFilter(p)

    // Based on Option#withFilter
    class WithFilter(p: Int => Boolean) {
        def map(f: Int => Int): MaybeInt = self filter p map f
        def flatMap(f: Int => MaybeInt): MaybeInt = self filter p flatMap f
        def foreach[U](f: Int => U): Unit = self filter p foreach f
        def withFilter(q: Int => Boolean): WithFilter = new WithFilter(x => p(x) && q(x))
    }
}

case class SomeInt(i: Int) extends MaybeInt {
    def map(f: Int => Int): MaybeInt = SomeInt(f(i))
    def flatMap(f: Int => MaybeInt): MaybeInt = f(i)
    def filter(f: Int => Boolean): MaybeInt = if(f(i)) this else NoInt
    def foreach[U](f: Int => U): Unit = f(i)
}

case object NoInt extends MaybeInt {
    def map(f: Int => Int): MaybeInt = NoInt
    def flatMap(f: Int => MaybeInt): MaybeInt = NoInt
    def filter(f: Int => Boolean): MaybeInt = NoInt
    def foreach[U](f: Int => U): Unit = ()
}
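If you want something closer to a real monad that is not tied to Int, the same pattern generalises with a type parameter. The following is only a sketch under my own assumptions: Maybe, Just and Empty are hypothetical names, and withFilter simply delegates to the strict filter, which is a simplification of the lazy WithFilter wrapper shown above.

```scala
object GenericMaybe {
  sealed trait Maybe[+A] {
    def map[B](f: A => B): Maybe[B]
    def flatMap[B](f: A => Maybe[B]): Maybe[B]
    def filter(p: A => Boolean): Maybe[A]
    def foreach[U](f: A => U): Unit
    // Simplification: strict filtering is good enough for a
    // container that holds at most one value.
    def withFilter(p: A => Boolean): Maybe[A] = filter(p)
  }

  final case class Just[+A](value: A) extends Maybe[A] {
    def map[B](f: A => B): Maybe[B] = Just(f(value))
    def flatMap[B](f: A => Maybe[B]): Maybe[B] = f(value)
    def filter(p: A => Boolean): Maybe[A] = if (p(value)) this else Empty
    def foreach[U](f: A => U): Unit = { f(value); () }
  }

  case object Empty extends Maybe[Nothing] {
    def map[B](f: Nothing => B): Maybe[B] = Empty
    def flatMap[B](f: Nothing => Maybe[B]): Maybe[B] = Empty
    def filter(p: Nothing => Boolean): Maybe[Nothing] = Empty
    def foreach[U](f: Nothing => U): Unit = ()
  }

  // Because map is generic, value definitions (b = ...) work too.
  val doubled: Maybe[Int] =
    for {
      a <- Just(5)
      b = a * 2
      c <- Just(b * 2)
    } yield c

  // Mixing element types and using a guard.
  val labelled: Maybe[String] =
    for {
      a <- Just(20)
      s <- Just("x" * 2) if a > 10
    } yield s"$a$s"

  // An Empty anywhere in the chain short-circuits the result.
  val nothing: Maybe[Int] =
    for {
      a <- Just(1)
      b <- (Empty: Maybe[Int])
    } yield a + b
}
```

The generic map is also what lets a plain `b = a * 2` line compile, since the compiler threads a tuple of (a, b) through map behind the scenes; the Int-only MaybeInt above cannot do that.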

Upvotes: 24
