Nick ten Veen

Capturing and chaining types used in for comprehension

I am looking for a way to capture the types that are used during a for comprehension in the type of the comprehension itself. To do this, I specified a rough interface:

trait Chain[A]{
  type ChainMethod = A => A //type of the method chained so far

  def flatMap[B](f: A => Chain[B]): Chain[B] //the ChainMethod needs to be included in the return type somehow
  def map[B](f: A => B): Chain[B]

  def fill: ChainMethod //Function has to be uncurried here
}

As an example, here are a few concrete types of Chain (their implementations are omitted):

object StringChain extends Chain[String]
object IntChain extends Chain[Int]

And a case class that will be used:

case class User(name:String, age:Int)

A chain can be created with a for comprehension:

val form = for{
  name <- StringChain
  age <- IntChain
} yield User(name, age)
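For context, Scala rewrites a for comprehension into nested flatMap and map calls, so the form above becomes StringChain.flatMap(name => IntChain.map(age => User(name, age))), and its type is whatever that outer flatMap returns. The sketch below shows the same rewriting with Option instead of Chain, because the Chain above is only an interface; DesugarDemo is a hypothetical name used purely for illustration:

```scala
object DesugarDemo {
  // A for comprehension over Option with two generators
  val sugared: Option[(String, Int)] = for {
    name <- Some("John")
    age  <- Some(25)
  } yield (name, age)

  // The same expression as the compiler rewrites it: every generator except
  // the last becomes flatMap, and the last one becomes map
  val desugared: Option[(String, Int)] =
    Some("John").flatMap(name => Some(25).map(age => (name, age)))
}
```

This is why the element types of the generators can only end up in the type of form if flatMap's return type carries them.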

The type of form should be

Chain[User]{type ChainMethod = String => Int => User}

so that we can do the following:

form.fill("John", 25) //should return User("John", 25)
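The "uncurried" remark on fill refers to going from the curried shape String => Int => User to a two-argument function that can be called as fill("John", 25). For a plain function value the standard library's Function.uncurried does exactly that; the sketch below (UncurryDemo is a hypothetical name) only illustrates the shape and does not solve the type-level problem:

```scala
object UncurryDemo {
  case class User(name: String, age: Int)

  // The curried shape from the desired ChainMethod type
  val curried: String => Int => User = name => age => User(name, age)

  // Function.uncurried turns it into a two-argument function,
  // the shape that form.fill("John", 25) would need
  val uncurried: (String, Int) => User = Function.uncurried(curried)
}
```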

I tried a few approaches, using structural types and a specialized FlatMappedChain trait, but I cannot get the type system to behave the way I want. I would welcome ideas or suggestions on how to specify the interface so that the compiler can infer this type, if that is possible at all.

Answers (1)

Kolmar

I think it's pretty hard to do this from scratch in Scala. You would likely have to define a lot of implicit classes for functions of different arity.
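To illustrate the trade-off, here is a dependency-free sketch that sidesteps arity-specific code by accumulating the arguments as nested pairs instead of a flat tuple. PairChain and PairChainDemo are hypothetical names, and this is an alternative encoding, not the approach used in this answer:

```scala
// Hypothetical variant without shapeless: arguments accumulate as nested pairs.
abstract class PairChain[A, O] { self =>
  def fill(a: A): O

  def map[O2](f: O => O2): PairChain[A, O2] = new PairChain[A, O2] {
    def fill(a: A): O2 = f(self.fill(a))
  }

  // Pair our own argument type with the next chain's argument type,
  // so the generator types of the for comprehension show up in A.
  def flatMap[A2, O2](next: O => PairChain[A2, O2]): PairChain[(A, A2), O2] =
    new PairChain[(A, A2), O2] {
      def fill(a: (A, A2)): O2 = next(self.fill(a._1)).fill(a._2)
    }
}

object PairChain {
  // A single field: fill just returns its argument
  def of[T]: PairChain[T, T] = new PairChain[T, T] {
    def fill(a: T): T = a
  }
}

object PairChainDemo {
  case class User(name: String, age: Int)

  val form: PairChain[(String, Int), User] = for {
    name <- PairChain.of[String]
    age  <- PairChain.of[Int]
  } yield User(name, age)
}
```

Here PairChainDemo.form.fill(("John", 25)) gives User("John", 25). With three generators the argument type nests to the right, for example (Int, (String, String)), so this encoding only partly achieves the flat fill call from the question; flattening that nesting is the part where a type-level library helps.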

It becomes easier if you use the shapeless library, which is designed for type-level computations. The following code uses a slightly different approach, where Chain.fill is a function from a tuple of arguments to the result. This implementation of flatMap also allows several forms to be combined into one:

import shapeless._
import shapeless.ops.{tuple => tp}

object Chain {
  def of[T]: Chain[Tuple1[T], T] = new Chain[Tuple1[T], T] {
    def fill(a: Tuple1[T]) = a._1
  }
}

/** @tparam A Tuple of arguments for `fill`
  * @tparam O Result of `fill`
  */
abstract class Chain[A, O] { self =>
  def fill(a: A): O

  def flatMap[A2, O2, Len <: Nat, R](next: O => Chain[A2, O2])(
    implicit
      // Append tuple A2 to tuple A to get a single tuple R
      prepend: tp.Prepend.Aux[A, A2, R],
      // Compute length Len of tuple A
      length: tp.Length.Aux[A, Len],
      // Take the first Len elements of tuple R,
      // and assert that they are equivalent to A
      take: tp.Take.Aux[R, Len, A],
      // Drop the first Len elements of tuple R,
      // and assert that the rest are equivalent to A2
      drop: tp.Drop.Aux[R, Len, A2]
  ): Chain[R, O2] = new Chain[R, O2] {
    def fill(r: R): O2 = next(self.fill(take(r))).fill(drop(r))
  }

  def map[O2](f: O => O2): Chain[A, O2] = new Chain[A, O2] {
    def fill(a: A): O2 = f(self.fill(a))
  }
}

And here is how you can use it:

scala> case class Address(country: String, city: String)
defined class Address

scala> case class User(id: Int, name: String, address: Address)
defined class User

scala> val addressForm = for {
  country <- Chain.of[String]
  city <- Chain.of[String]
} yield Address(country, city)
addressForm: com.Main.Chain[this.Out,Address] = com.Main$Chain$$anon$2@3253e213

scala> val userForm = for {
  id <- Chain.of[Int]
  name <- Chain.of[String]
  address <- addressForm
} yield User(id, name, address)
userForm: com.Main.Chain[this.Out,User] = com.Main$Chain$$anon$2@7ad40950

scala> userForm.fill(1, "John", "USA", "New York")
res0: User = User(1,John,Address(USA,New York))
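To see why userForm.fill takes four flat arguments, it may help to trace the types by hand. The comprehension desugars to Chain.of[Int].flatMap(id => Chain.of[String].flatMap(name => addressForm.map(address => User(id, name, address)))), and at each flatMap level Len is 1, so take peels off the first element for the outer chain and drop passes the rest on. The plain-Scala sketch below writes those splits out by hand; FillTrace, addressFill and userFill are hypothetical stand-ins, not shapeless code:

```scala
object FillTrace {
  case class Address(country: String, city: String)
  case class User(id: Int, name: String, address: Address)

  // Stand-in for addressForm.fill, whose argument tuple is (String, String)
  def addressFill(a: (String, String)): Address = Address(a._1, a._2)

  // Stand-in for userForm.fill, whose argument tuple R is (Int, String, String, String)
  def userFill(r: (Int, String, String, String)): User = {
    // Outer flatMap (id <- Chain.of[Int]): A = Tuple1[Int], Len = 1
    val taken1: Tuple1[Int] = Tuple1(r._1)                   // take(r)
    val rest1: (String, String, String) = (r._2, r._3, r._4) // drop(r)
    // Inner flatMap (name <- Chain.of[String]): A = Tuple1[String], Len = 1
    val taken2: Tuple1[String] = Tuple1(rest1._1)            // take(rest1)
    val rest2: (String, String) = (rest1._2, rest1._3)       // drop(rest1)
    // The remaining pair is exactly addressForm's argument tuple
    User(taken1._1, taken2._1, addressFill(rest2))
  }
}
```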
