Reputation: 399
For a homework assignment I wrote some scala code in which I have the following classes and object (used for modeling a binary tree):
object Tree {
  def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
    case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
    case _ => e
  }
  def sumTree(t: Tree): Tree =
    fold(t, Nil(), (a, b: Tree, c: Tree) => {
      val left = b match {
        case Node(value, _, _) => value
        case _ => 0
      }
      val right = c match {
        case Node(value, _, _) => value
        case _ => 0
      }
      Node(a + left + right, b, c)
    })
}
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case class Nil() extends Tree
My question is about the sumTree function, which creates a new tree in which each node's value equals the sum of its children's (already summed) values plus its own value, so every node ends up holding the total of its subtree.
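To make the intended result concrete, here is a minimal sketch that runs the question's fold and sumTree logic on a small, hypothetical tree. It assumes Scala 2.13 or 3 in script form (for example a Scala CLI `.sc` file); the classes are adjusted to `abstract class Tree` and `case object Nil` so they compile on current versions, and fold is called as `fold[Tree]` to keep type inference simple.

```scala
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree

// fold from the question: combine a node's value with the folded results of both subtrees.
def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
  case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
  case _                 => e
}

// The question's sumTree logic: l and r are the already-summed subtrees.
def sumTree(t: Tree): Tree =
  fold[Tree](t, Nil, (a, l, r) => {
    val left = l match {
      case Node(value, _, _) => value
      case _                 => 0
    }
    val right = r match {
      case Node(value, _, _) => value
      case _                 => 0
    }
    Node(a + left + right, l, r)
  })

// Hypothetical sample tree: 1 has children 2 and 3, and 2 has a left child 4.
val example = Node(1, Node(2, Node(4, Nil, Nil), Nil), Node(3, Nil, Nil))
val summed  = sumTree(example)
// summed is Node(10, Node(6, Node(4, Nil, Nil), Nil), Node(3, Nil, Nil))
```

Because the fold works bottom-up, node 2 becomes 2 + 4 = 6 before the root is built, and the root becomes 1 + 6 + 3 = 10.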
I find it rather ugly-looking, and I wonder if there is a better way to do this. If I used recursion that works top-down this would be easier, but I could not come up with such a function.
I have to implement the fold function, with the signature shown in the code, and use it to calculate sumTree.
I have the feeling this can be implemented in a better way; do you have any suggestions?
Upvotes: 12
Views: 5080
Reputation: 1934
You've probably turned in your homework already, but I think it's still worth pointing out that the way your code (and the code in other people's answers) looks is a direct result of how you modeled the binary trees. If, instead of using an algebraic data type (Tree, Node, Nil), you had gone with a single recursive case class whose children are optional, you wouldn't have had to use pattern matching to decompose your binary trees. Here's my definition of a binary tree:
case class Tree[A](value: A, left: Option[Tree[A]], right: Option[Tree[A]])
As you can see, there's no need for Node or Nil here (the latter is just a glorified null anyway; you don't want anything like that in your code, do you?).
With such a definition, fold is essentially a one-liner:
def fold[A, B](t: Tree[A], z: B)(op: (A, B, B) => B): B =
  op(t.value, t.left map (fold(_, z)(op)) getOrElse z, t.right map (fold(_, z)(op)) getOrElse z)
And sumTree
is also short and sweet:
def sumTree(tree: Tree[Int]) = fold(tree, None: Option[Tree[Int]]) { (value, left, right) =>
  Some(Tree(value + valueOf(left, 0) + valueOf(right, 0), left, right))
}.get
where the valueOf helper is defined as:
def valueOf[A](ot: Option[Tree[A]], df: A): A = ot map (_.value) getOrElse df
No pattern matching is needed anywhere, all thanks to a nice recursive definition of binary trees.
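Putting the three definitions above together gives something that can be run directly. This is a sketch assuming Scala 2.13 or 3 in script form; the sample trees `leaf2` and `leaf3` are hypothetical, and the infix `map ... getOrElse` calls are written with dots but behave the same.

```scala
// Option-based binary tree: a missing child is None.
case class Tree[A](value: A, left: Option[Tree[A]], right: Option[Tree[A]])

// Curried fold: the first argument list fixes A and B, so `op` needs no type annotations.
def fold[A, B](t: Tree[A], z: B)(op: (A, B, B) => B): B =
  op(t.value,
     t.left.map(fold(_, z)(op)).getOrElse(z),
     t.right.map(fold(_, z)(op)).getOrElse(z))

// Value at the root of an optional subtree, or a default for None.
def valueOf[A](ot: Option[Tree[A]], df: A): A = ot.map(_.value).getOrElse(df)

def sumTree(tree: Tree[Int]): Tree[Int] =
  fold(tree, None: Option[Tree[Int]]) { (value, left, right) =>
    Some(Tree(value + valueOf(left, 0) + valueOf(right, 0), left, right))
  }.get

// Hypothetical sample: root 1 with leaf children 2 and 3.
val leaf2  = Tree(2, None, None)
val leaf3  = Tree(3, None, None)
val summed = sumTree(Tree(1, Some(leaf2), Some(leaf3)))
```

The `None: Option[Tree[Int]]` ascription matters: without it B would be inferred as `None.type`, and the `Some(...)` returned by the function literal would not conform.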
Upvotes: 1
Reputation: 4385
As Vlad writes, your solution has about the only general shape you can have with such a fold.
Still, there is a way to get rid of the node-value matching rather than just factor it out, and personally I would prefer it that way.
You use match because not every result you get from the recursive fold carries a sum with it. True, not every Tree can carry one, since Nil has no place for a value, but your fold is not limited to Trees, is it?
So let's have:
case class TreePlus[A](value: A, tree: Tree)
Now we can fold it like this:
def sumTree(t: Tree) = fold[TreePlus[Int]](t, TreePlus(0, Nil()), (v, l, r) => {
  // l and r already carry the sums of the left and right subtrees
  val sum = v + l.value + r.value
  TreePlus(sum, Node(sum, l.tree, r.tree))
}).tree
Of course, TreePlus is not really needed, as the standard library already has the canonical product Tuple2.
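Following that remark, here is a sketch of the same idea with a plain Tuple2 instead of TreePlus: each fold step returns a pair (sum of the subtree, rebuilt subtree). It assumes the question's fold and Node, with Nil written as a case object; the class definitions are repeated only to keep the snippet self-contained.

```scala
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree

def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
  case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
  case _                 => e
}

// Fold into (subtree sum, rebuilt subtree) pairs; the empty tree contributes (0, Nil).
def sumTree(t: Tree): Tree =
  fold[(Int, Tree)](t, (0, Nil), (v, l, r) => {
    val sum = v + l._1 + r._1
    (sum, Node(sum, l._2, r._2))
  })._2
```

The sum travels alongside the subtree, so no pattern match on Node is needed to read it back; the price is the less descriptive `_1`/`_2` accessors.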
Upvotes: 2
Reputation: 23502
First of all, if I may say so, you've done a very good job. I can suggest a couple of slight changes to your code:
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree
Nil is a singleton and is best defined as a case object instead of a case class. Additionally, consider marking the superclass Tree as sealed. sealed tells the compiler that the class can only be extended from within the same source file. This lets the compiler emit a warning whenever a match expression over Tree is not exhaustive, in other words, when it doesn't cover all possible cases.
sealed abstract class Tree
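A small sketch of what sealed buys you, assuming the class hierarchy suggested above. Listing Nil explicitly (rather than using a catch-all `case _`) is an illustrative choice: with a wildcard the match is trivially exhaustive and there is nothing left for the compiler to check.

```scala
// sealed: every direct subclass of Tree must be defined in this source file,
// so the compiler knows that Node and Nil are the only cases.
sealed abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree

// Both cases are listed explicitly instead of using `case _`.
// Deleting the `case Nil` line should make the compiler warn that the match is not exhaustive.
def nodeValue(t: Tree): Int = t match {
  case Node(v, _, _) => v
  case Nil           => 0
}
```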
A couple more improvements could be made to sumTree:
def sumTree(t: Tree) = {
  // create a helper function to extract the Tree value
  val nodeValue: Tree => Int = {
    case Node(v, _, _) => v
    case _ => 0
  }
  // parametrise fold with Tree to aid type inference further down the line
  fold[Tree](t, Nil, (acc, l, r) => Node(acc + nodeValue(l) + nodeValue(r), l, r))
}
The nodeValue helper can also be defined as an ordinary method (the alternative notation I used above works because a sequence of cases in curly braces is treated as a function literal):
def nodeValue(t: Tree) = t match {
  case Node(v, _, _) => v
  case _ => 0
}
The next small improvement is parametrising the fold method with Tree (fold[Tree]). In Scala 2, type arguments are inferred one argument list at a time, and fold takes all of its arguments in a single list, so B is not yet known when the compiler types the function literal. Telling it up front that we're dealing with Trees lets us omit the parameter types in the function literal passed to fold.
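An alternative, only available if the assignment did not fix fold's signature, is to curry it so that B is determined by the first argument list. This is a hedged sketch: `foldC` is a hypothetical name, and the seed is written as `Nil: Tree` so that B is inferred as Tree rather than `Nil.type`.

```scala
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree

// Hypothetical curried variant of fold: B is fixed by (t, e),
// so the function literal in the second list needs no type annotations.
def foldC[B](t: Tree, e: B)(n: (Int, B, B) => B): B = t match {
  case Node(value, l, r) => n(value, foldC(l, e)(n), foldC(r, e)(n))
  case _                 => e
}

def nodeValue(t: Tree): Int = t match {
  case Node(v, _, _) => v
  case _             => 0
}

// No explicit [Tree] and no parameter types on (v, l, r).
def sumTree(t: Tree): Tree =
  foldC(t, Nil: Tree)((v, l, r) => Node(v + nodeValue(l) + nodeValue(r), l, r))
```

The same curried shape is what makes `fold(tree, z) { (value, left, right) => ... }` read naturally in the Option-based answer.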
So here is the full code with the suggestions applied:
sealed abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree
object Tree {
  def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
    case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
    case _ => e
  }
  def sumTree(t: Tree) = {
    val nodeValue: Tree => Int = {
      case Node(v, _, _) => v
      case _ => 0
    }
    fold[Tree](t, Nil, (acc, l, r) => Node(acc + nodeValue(l) + nodeValue(r), l, r))
  }
}
The recursion you came up with goes in the only possible direction for traversing the tree while producing a modified copy of the immutable data structure. Because the nodes are immutable, everything needed to construct a node, including its children, has to exist before that node is created, so the leaf nodes must be built before the root.
Upvotes: 12
Reputation: 20515
Your solution is probably more efficient (it certainly uses less stack), but here's a directly recursive solution, FWIW:
def sum(tree: Tree): Tree = tree match {
  case Node(a, b, c) =>
    // sum the subtrees first, then add their root values to this node's value
    val left = sum(b)
    val right = sum(c)
    Node(a + total(left) + total(right), left, right)
  case other => other // the empty tree (Nil) stays empty
}

def total(tree: Tree): Int = tree match {
  case Node(a, _, _) => a
  case _ => 0
}
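To sanity-check the direct recursion, one can compare it with a fold-based sumTree on a sample tree. This is a sketch assuming Node plus a case-object Nil; `sample` is a hypothetical test tree.

```scala
abstract class Tree
case class Node(value: Int, left: Tree, right: Tree) extends Tree
case object Nil extends Tree

def fold[B](t: Tree, e: B, n: (Int, B, B) => B): B = t match {
  case Node(value, l, r) => n(value, fold(l, e, n), fold(r, e, n))
  case _                 => e
}

// Root value of a tree, 0 for the empty tree.
def total(tree: Tree): Int = tree match {
  case Node(a, _, _) => a
  case _             => 0
}

// Direct recursion: sum both subtrees, then build the new node.
def sum(tree: Tree): Tree = tree match {
  case Node(a, b, c) =>
    val left  = sum(b)
    val right = sum(c)
    Node(a + total(left) + total(right), left, right)
  case other => other
}

// Fold-based version, for comparison.
def sumTree(t: Tree): Tree =
  fold[Tree](t, Nil, (a, l, r) => Node(a + total(l) + total(r), l, r))

val sample = Node(1, Node(2, Node(4, Nil, Nil), Nil), Node(3, Nil, Node(5, Nil, Nil)))
```

Both versions build the new tree bottom-up: here node 2 becomes 6, node 3 becomes 8, and the root becomes 15.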
Upvotes: 1