Michael Korbakov

Reputation: 2177

How to make tree implemented in Scala useful with higher-order collection functions?

I have a simple tree structure in Scala implemented like this:

sealed abstract class FactsQueryAst[FactType] {
}

object FactsQueryAst {
  case class AndNode[FactType](subqueries: Seq[FactsQueryAst[FactType]]) extends FactsQueryAst[FactType]
  case class OrNode[FactType](subqueries: Seq[FactsQueryAst[FactType]]) extends FactsQueryAst[FactType]
  case class Condition[FactType](fact: FactType, value: FactValue) extends FactsQueryAst[FactType]
}

Are there any relatively simple ways to make this structure work with higher-order functions like map, foldLeft or filter? There's a good article about implementing the Traversable trait for your own collections (http://daily-scala.blogspot.com/2010/04/creating-custom-traversable.html), but it seems overcomplicated for the tree case, or at least I'm missing something fundamental.

UPD. I tried to implement a naive Traversable as below, but it results in an infinite loop just from printing the value.

sealed abstract class FactsQueryAst[FactType] extends Traversable[FactsQueryAst.Condition[FactType]]

object FactsQueryAst {
  case class AndNode[FactType](subqueries: Seq[FactsQueryAst[FactType]]) extends FactsQueryAst[FactType] {
    def foreach[U](f: (Condition[FactType]) => U) {subqueries foreach {_.foreach(f)}}
  }
  case class OrNode[FactType](subqueries: Seq[FactsQueryAst[FactType]]) extends FactsQueryAst[FactType] {
    def foreach[U](f: (Condition[FactType]) => U) {subqueries foreach {_.foreach(f)}}
  }
  case class Condition[FactType](fact: FactType, value: FactValue) extends FactsQueryAst[FactType]{
    def foreach[U](f: (Condition[FactType]) => U) {f(this)}
  }
}

The stack trace for infinite loop looks like this:

at tellmemore.queries.FactsQueryAst$Condition.stringPrefix(FactsQueryAst.scala:65532)
at scala.collection.TraversableLike$class.toString(TraversableLike.scala:639)
at tellmemore.queries.FactsQueryAst.toString(FactsQueryAst.scala:5)
at java.lang.String.valueOf(String.java:2854)
at scala.collection.mutable.StringBuilder.append(StringBuilder.scala:197)
at scala.collection.TraversableOnce$$anonfun$addString$1.apply(TraversableOnce.scala:322)
at tellmemore.queries.FactsQueryAst$Condition.foreach(FactsQueryAst.scala:23)
at scala.collection.TraversableOnce$class.addString(TraversableOnce.scala:320)
at tellmemore.queries.FactsQueryAst.addString(FactsQueryAst.scala:5)
at scala.collection.TraversableOnce$class.mkString(TraversableOnce.scala:286)
at tellmemore.queries.FactsQueryAst.mkString(FactsQueryAst.scala:5)
at scala.collection.TraversableLike$class.toString(TraversableLike.scala:639)
at tellmemore.queries.FactsQueryAst.toString(FactsQueryAst.scala:5)
at java.lang.String.valueOf(String.java:2854)
at scala.collection.mutable.StringBuilder.append(StringBuilder.scala:197)
at scala.collection.TraversableOnce$$anonfun$addString$1.apply(TraversableOnce.scala:322)
at tellmemore.queries.FactsQueryAst$Condition.foreach(FactsQueryAst.scala:23)

Upvotes: 3

Views: 1053

Answers (3)

Felix

Reputation: 8495

There is a good chance that you need to look into the traversable-builder pattern. It is not so straightforward, but it preserves structure between operations using a factory-like pattern. I recommend looking into this:

http://docs.scala-lang.org/overviews/core/architecture-of-scala-collections.html

The pattern needs an implicit builder object in your companion object, as well as an implementation of some traversable (e.g. implementing a trait with foreach and possibly some higher-order methods).

I hope this is useful to you :)

Upvotes: 1

huynhjl

Reputation: 41646

I am going to post my answer on how to preserve structure in this separate answer, since the other one is getting too long (and the question came up in your later comment). Although I suspect using Kiama is a better choice, here is a way to preserve your structure and have operations similar to foldLeft, map and filter. I have made both FactType and FactValue type parameters so that I can give shorter examples with Int and String, and also because it is more general.

The idea is that your tree is a recursive structure built from 3 constructors: AndNode and OrNode take sequences, and Condition takes two arguments. I define a fold function that recursively transforms this structure into another type R, given 3 functions, one for each of the constructors:

sealed abstract class FactsQueryAst[T, V] {
  import FactsQueryAst._
  def fold[R](fAnd: Seq[R] => R, fOr: Seq[R] => R, fCond: (T, V) => R): R = this match {
    case AndNode(seq) => fAnd(seq.map(_.fold(fAnd, fOr, fCond)))
    case OrNode(seq) => fOr(seq.map(_.fold(fAnd, fOr, fCond)))
    case Condition(t, v) => fCond(t, v)
  }
  def mapConditions[U, W](f: (T, V) => (U, W)) =
    fold[FactsQueryAst[U, W]](
      AndNode(_),
      OrNode(_),
      (t, v) => {val uw = f(t, v); Condition(uw._1, uw._2) })
}

object FactsQueryAst {
  case class AndNode[T, V](subqueries: Seq[FactsQueryAst[T, V]]) extends FactsQueryAst[T, V]
  case class OrNode[T, V](subqueries: Seq[FactsQueryAst[T, V]]) extends FactsQueryAst[T, V]
  case class Condition[T, V](factType: T, value: V) extends FactsQueryAst[T, V]
}

mapConditions is implemented in terms of fold, and many other functions can be as well. Here it is in action:

object so {
  import FactsQueryAst._
  val ast =
    OrNode(
      Seq(
        AndNode(
          Seq(Condition(1, "one"))),
        AndNode(
          Seq(Condition(3, "three")))))
  //> ast  : worksheets.FactsQueryAst.OrNode[Int,String] =
  //    OrNode(List(AndNode(List(Condition(1,one))), 
  //                AndNode(List(Condition(3,three)))))

  val doubled = ast.mapConditions{case (t, v) => (t*2, v*2) }

  //> doubled  : worksheets.FactsQueryAst[Int,String] = 
  //    OrNode(List(AndNode(List(Condition(2,oneone))), 
  //                AndNode(List(Condition(6,threethree)))))

  val sums = ast.fold[(Int, String)](
    seq => seq.reduceLeft((a, b) => (a._1 + b._1, a._2 + b._2)),
    seq => seq.reduceLeft((a, b) => (a._1 + b._1, a._2 + b._2)),
    (t, v) => (t, v))
  //> sums  : (Int, String) = (4,onethree)

  val andOrSwitch = ast.fold[FactsQueryAst[Int, String]](
    OrNode(_),
    AndNode(_),
    (t, v) => Condition(t, v))
  //> andOrSwitch  : worksheets.FactsQueryAst[Int,String] = 
  //    AndNode(List(OrNode(List(Condition(1,one))), 
  //                 OrNode(List(Condition(3,three)))))

  val asList = ast.fold[List[(Int, String)]](
    _.reduceRight(_ ::: _),
    _.reduceRight(_ ::: _),
    (t, v) => List((t, v)))
  //> asList  : List[(Int, String)] = List((1,one), (3,three))  
}
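The answer mentions filter but only shows map-like and fold-like operations. Here is a minimal sketch of a structure-preserving filter, plus a query evaluator, both written purely in terms of the same fold. It assumes the two-parameter FactsQueryAst from above; filterConditions and evaluate are hypothetical helper names, and dropping a node whose children were all filtered out is a design choice of this sketch, not something the answer specifies.

```scala
sealed abstract class FactsQueryAst[T, V] {
  import FactsQueryAst._

  // Same fold as above: one function per constructor.
  def fold[R](fAnd: Seq[R] => R, fOr: Seq[R] => R, fCond: (T, V) => R): R = this match {
    case AndNode(seq)    => fAnd(seq.map(_.fold(fAnd, fOr, fCond)))
    case OrNode(seq)     => fOr(seq.map(_.fold(fAnd, fOr, fCond)))
    case Condition(t, v) => fCond(t, v)
  }

  // Hypothetical helper: keep only the conditions satisfying p, preserving
  // the And/Or shape. A node left with no children is dropped as well, so
  // the whole result is None when nothing matches.
  def filterConditions(p: (T, V) => Boolean): Option[FactsQueryAst[T, V]] =
    fold[Option[FactsQueryAst[T, V]]](
      kids => { val kept = kids.flatten; if (kept.isEmpty) None else Some(AndNode(kept)) },
      kids => { val kept = kids.flatten; if (kept.isEmpty) None else Some(OrNode(kept)) },
      (t, v) => if (p(t, v)) Some(Condition(t, v)) else None)

  // Hypothetical helper: evaluate the query against a map of known facts,
  // reading AndNode as "all children hold" and OrNode as "any child holds".
  def evaluate(facts: Map[T, V]): Boolean =
    fold[Boolean](_.forall(identity), _.exists(identity), (t, v) => facts.get(t).contains(v))
}

object FactsQueryAst {
  case class AndNode[T, V](subqueries: Seq[FactsQueryAst[T, V]]) extends FactsQueryAst[T, V]
  case class OrNode[T, V](subqueries: Seq[FactsQueryAst[T, V]]) extends FactsQueryAst[T, V]
  case class Condition[T, V](factType: T, value: V) extends FactsQueryAst[T, V]
}

import FactsQueryAst._

val ast = OrNode(Seq(
  AndNode(Seq(Condition(1, "one"), Condition(2, "two"))),
  AndNode(Seq(Condition(3, "three")))))

val evensOnly = ast.filterConditions((t, _) => t % 2 == 0)
// Some(OrNode(List(AndNode(List(Condition(2,two))))))

val matches = ast.evaluate(Map(3 -> "three")) // true: the second AndNode holds
```

Returning an Option keeps the result a real FactsQueryAst (or nothing at all), which is exactly the structure that a Traversable-based map or filter throws away.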

Upvotes: 1

huynhjl

Reputation: 41646

Let's rename FactType to something that looks more like a type parameter. I think naming it just T helps indicate it is a type parameter versus a meaningful class in your code:

sealed abstract class FactsQueryAst[T] extends Traversable[T]

So FactsQueryAst contains things of type T, and we want to be able to traverse the tree to do something for each t: T. The method to implement is:

def foreach[U](f: T => U): Unit

So, replacing every FactType in your code with T and changing the signature of foreach accordingly, I end up with:

object FactsQueryAst {
  case class AndNode[T](subqueries: Seq[FactsQueryAst[T]]) extends FactsQueryAst[T] {
    def foreach[U](f: T => U) { subqueries foreach { _.foreach(f) } }
  }
  case class OrNode[T](subqueries: Seq[FactsQueryAst[T]]) extends FactsQueryAst[T] {
    def foreach[U](f: T => U) { subqueries foreach { _.foreach(f) } }
  }
  case class Condition[T](factType: T, value: FactValue) extends FactsQueryAst[T] {
    def foreach[U](f: T => U) { f(factType) }
  }
}

This works like this:

import FactsQueryAst._
case class FactValue(v: String)
val t =
  OrNode(
    Seq(
      AndNode(
        Seq(Condition(1, FactValue("one")), Condition(2, FactValue("two")))),
      AndNode(
        Seq(Condition(3, FactValue("three"))))))
//> t  : worksheets.FactsQueryAst.OrNode[Int] = FactsQueryAst(1, 2, 3)
t.map(i => i + 1)
//> res0: Traversable[Int] = List(2, 3, 4)

Obviously, implementing Traversable loses the structure when you map over it, but that may be enough for your use case. If you need something more specific, you can ask another question.
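The Traversable-based code above targets the pre-2.13 collections library. On Scala 2.13 and later, Traversable is only a deprecated alias for Iterable, and the single abstract member to implement is iterator rather than foreach. A minimal sketch of the same tree under that assumption (FactValue is defined here only so the snippet compiles on its own):

```scala
// Assumes Scala 2.13+: extend Iterable and implement `iterator`.
case class FactValue(v: String)

sealed abstract class FactsQueryAst[T] extends Iterable[T]

object FactsQueryAst {
  case class AndNode[T](subqueries: Seq[FactsQueryAst[T]]) extends FactsQueryAst[T] {
    // Visit the children left to right, depth first.
    def iterator: Iterator[T] = subqueries.iterator.flatMap(_.iterator)
  }
  case class OrNode[T](subqueries: Seq[FactsQueryAst[T]]) extends FactsQueryAst[T] {
    def iterator: Iterator[T] = subqueries.iterator.flatMap(_.iterator)
  }
  case class Condition[T](factType: T, value: FactValue) extends FactsQueryAst[T] {
    // A leaf yields its payload, not itself.
    def iterator: Iterator[T] = Iterator.single(factType)
  }
}

import FactsQueryAst._

val t = OrNode(Seq(
  AndNode(Seq(Condition(1, FactValue("one")), Condition(2, FactValue("two")))),
  AndNode(Seq(Condition(3, FactValue("three"))))))

val incremented = t.map(_ + 1)     // a flat Iterable[Int] holding 2, 3, 4
val total = t.foldLeft(0)(_ + _)   // 6
```

As with the Traversable version, map and filter return a flat Iterable; the And/Or shape is not preserved.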

Edit:

It turns out your initial version would probably work. Here is an almost identical version, but notice that I override toString in Condition. I suspect that if you override toString() your version will work too:

case class FactValue(v: String)
case class FactType(t: Int)

sealed abstract class FactsQueryAst extends Traversable[FactsQueryAst.Condition]

object FactsQueryAst {
  case class AndNode(subqueries: Seq[FactsQueryAst]) extends FactsQueryAst {
    def foreach[U](f: Condition => U) { subqueries foreach { _.foreach(f) } }
  }
  case class OrNode(subqueries: Seq[FactsQueryAst]) extends FactsQueryAst {
    def foreach[U](f: Condition => U) { subqueries foreach { _.foreach(f) } }
  }
  case class Condition(factType: FactType, value: FactValue)  extends FactsQueryAst {
    def foreach[U](f: Condition => U) { f(this) }
    override def toString() = s"Cond($factType, $value)"
  }
}

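The infinite recursion happens when trying to print the object. Because TraversableLike already provides a concrete toString, Condition does not get the usual case-class toString. Printing it therefore calls TraversableLike.toString, which calls mkString, which calls addString, which calls foreach; Condition.foreach passes the Condition itself to the string builder, which calls its toString again, and the cycle repeats until the stack overflows. This is the loop visible in the stack trace in the question.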
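The rule behind this: a case class only gets a compiler-generated toString (and likewise equals and hashCode) when it neither defines one itself nor inherits a concrete one from a base class other than AnyRef. A minimal demonstration, using hypothetical classes Base, Plain and Derived that are unrelated to the asker's code:

```scala
// A base class that supplies a concrete toString, as TraversableLike does.
abstract class Base {
  override def toString: String = "inherited"
}

case class Plain(x: Int)                 // no inherited toString: the compiler generates one
case class Derived(x: Int) extends Base  // inherits Base.toString, so none is generated

val plain = Plain(1).toString     // "Plain(1)"
val derived = Derived(1).toString // "inherited"
```

Overriding toString explicitly in Condition, as in the version above, is what breaks the cycle.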

Upvotes: 6
