Reputation: 42050
It is pretty easy to write flatten(lol: List[List[T]]): List[T], which transforms a list of lists into a new list. Other "flat" collections (e.g. Set) seem to provide flatten too.
Now I wonder whether I can define a flatten for Tree[T] (defined as a T and a list of Tree[T]s).
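For reference, here is a minimal sketch of the list version mentioned above, next to the built-in flatten on List and Set. The foldRight definition is one possible way to write it by hand, not the standard library's implementation:

```scala
// Hand-written flatten for a list of lists.
// foldRight prepends each inner list onto the accumulated result with :::,
// which copies only the inner list, so the total work is linear.
def flatten[T](lol: List[List[T]]): List[T] =
  lol.foldRight(List.empty[T])((inner, acc) => inner ::: acc)

// The standard collections already provide flatten for nested collections:
val fromList = List(List(1, 2), List(3)).flatten // List(1, 2, 3)
val fromSet = Set(Set(1, 2), Set(2, 3)).flatten  // Set(1, 2, 3)
```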
Upvotes: 1
Views: 2793
Reputation: 6852
If I'm understanding the question correctly, you want to define Tree like so:
case class Tree[T]( value:T, kids:List[Tree[T]] )
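For comparison, the most direct recursive definition for this Tree is a one-liner. This is a sketch (naiveFlatten is a hypothetical name): it builds a fresh list at every level via kids.flatMap, so an element is copied once per ancestor, which is fine for shallow trees but quadratic for very deep ones:

```scala
// Rose tree: a value plus a list of subtrees.
case class Tree[T](value: T, kids: List[Tree[T]])

// Pre-order flatten: this node's value, then the values of each subtree in order.
// kids.flatMap copies each child's result into a new list at every level.
def naiveFlatten[T](t: Tree[T]): List[T] =
  t.value :: t.kids.flatMap(k => naiveFlatten(k))
```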
First, I wouldn't want to use ::: in the solution because of its performance implications. Second, I'd want to do something much more general: define a fold operator for the type, which can be used for all sorts of things, and then simply use that fold to define flatten:
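case class Tree[T]( value:T, kids:List[Tree[T]] ) {
  // Pre-order left fold: apply f to this node's value, then fold each subtree in turn
  // (List's own /: is deprecated since Scala 2.13, so foldLeft is used over the kids)
  def /:[A]( init:A )( f: (A,T) => A ):A =
    kids.foldLeft( f(init,value) )( (soFar,kid) => ( soFar /: kid )(f) )
  // Prepend each value (cheap on a List), then reverse once at the end
  def flatten =
    ( List.empty[T] /: this )( (soFar,value) => value::soFar ).reverse
}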
Here's a test:
scala> val t = Tree( 1, List( Tree( 2, List( Tree(3,Nil), Tree(4,Nil) ) ), Tree(5,Nil), Tree( 6, List( Tree(7,Nil) ) ) ) )
t: Tree[Int] = Tree(1,List(Tree(2,List(Tree(3,List()), Tree(4,List()))), Tree(5,List()), Tree(6,List(Tree(7,List())))))
scala> t.flatten
res15: List[Int] = List(1, 2, 3, 4, 5, 6, 7)
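Because the fold is general, other traversals are one-liners too. In this sketch the answer's /: operator is renamed to an ordinary method called fold (my renaming, only to avoid right-associative operator syntax), and sum, size and maxValue are hypothetical helpers:

```scala
case class Tree[T](value: T, kids: List[Tree[T]]) {
  // Same pre-order left fold as the /: operator, under an ordinary name.
  def fold[A](init: A)(f: (A, T) => A): A =
    kids.foldLeft(f(init, value))((soFar, kid) => kid.fold(soFar)(f))

  // flatten as in the answer: prepend each value, reverse once at the end.
  def flatten: List[T] =
    fold(List.empty[T])((soFar, v) => v :: soFar).reverse
}

// Other uses of the same fold.
def sum(t: Tree[Int]): Int = t.fold(0)(_ + _)
def size[T](t: Tree[T]): Int = t.fold(0)((n, _) => n + 1)
def maxValue(t: Tree[Int]): Int = t.fold(t.value)(_ max _)
```

On the test tree t above, these give 28, 7 and 7 respectively.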
Upvotes: 3
Reputation: 3872
I'm not sure how you want to define that flatten exactly, but you can look at the Scalaz Tree implementation:
https://github.com/scalaz/scalaz/blob/scalaz-seven/core/src/main/scala/scalaz/Tree.scala
If you want flatten to return a list of the values of all Tree nodes, Scalaz already provides what you want:
def flatten: Stream[A]
The result type is Stream instead of List, but I guess that is not a problem.
If you want something more sophisticated, you can probably implement it using the existing flatMap:
def flatMap[B](f: A => Tree[B]): Tree[B]
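To see what this flatMap does without pulling in Scalaz, here is a minimal stand-in. Rose and leaf are hypothetical names, List replaces Stream, and flatMap implements the usual rose-tree bind, which I believe is what the Scalaz version does: the tree returned for the root label supplies the new root, its own subtrees come first, and the recursively flatMapped original subtrees follow. flatten mirrors the pre-order flatten, returning a List:

```scala
// Hypothetical minimal rose tree standing in for scalaz.Tree.
final case class Rose[A](label: A, kids: List[Rose[A]]) {
  // Rose-tree bind: f(label) supplies the new root and its leading subtrees;
  // the original subtrees are flatMapped recursively and appended after them.
  def flatMap[B](f: A => Rose[B]): Rose[B] = {
    val r = f(label)
    Rose(r.label, r.kids ++ kids.map(_.flatMap(f)))
  }

  // Pre-order list of labels.
  def flatten: List[A] = label :: kids.flatMap(_.flatten)
}

// A node without children.
def leaf[A](a: A): Rose[A] = Rose(a, Nil)
```

For example, Rose(1, List(leaf(2))).flatMap(n => Rose(n, List(leaf(n * 10)))) gives Rose(1, List(leaf(10), Rose(2, List(leaf(20))))): each original node is replaced by the tree f builds for it.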
Let's say you have a Tree of type Tree[Tree[A]] and want to flatten it to a Tree[A]:
def flatten1[A](t: Tree[Tree[A]]): Tree[A] = t.flatMap(identity)
This also works for other, more unusual scenarios. For example, you might have a Tree[List[List[B]]] and want to flatten the nested lists inside each node without affecting the Tree structure itself:
def flatten2[B](t: Tree[List[List[B]]]): Tree[List[B]] = t.flatMap(l => leaf(l.flatten))
Looks like it works as expected:
scala> node(List(List(1)), Stream(node(List(List(2)), Stream(leaf(List(List(3, 4), List(5))))), leaf(List(List(4)))))
res20: scalaz.Tree[List[List[Int]]] = <tree>
scala> res20.flatMap(l => leaf(l.flatten)).drawTree
res23: String =
"List(1)
|
+- List(2)
| |
| `- List(3, 4, 5)
|
`- List(4)
"
It is worth noting that the Scalaz Tree is also a Monad. If you look at scalaz/tests/src/test/scala/scalaz/TreeTest.scala, you will see which laws are checked for Tree:
checkAll("Tree", equal.laws[Tree[Int]])
checkAll("Tree", traverse1.laws[Tree])
checkAll("Tree", applicative.laws[Tree])
checkAll("Tree", comonad.laws[Tree])
I don't know why monad is not in that list, but if you add checkAll("Tree", monad.laws[Tree]) and run the tests again, they pass.
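The monad laws can also be spot-checked by hand. This sketch reuses a hypothetical Rose stand-in (not scalaz.Tree) with the usual rose-tree bind and checks left identity, right identity and associativity on concrete values. It is only an illustration, not a substitute for the property-based checkAll laws in the Scalaz test suite:

```scala
// Hypothetical minimal rose tree with the rose-tree bind (List instead of Stream).
final case class Rose[A](label: A, kids: List[Rose[A]]) {
  def flatMap[B](f: A => Rose[B]): Rose[B] = {
    val r = f(label)
    Rose(r.label, r.kids ++ kids.map(_.flatMap(f)))
  }
}

// "pure" for this monad: wrap a value in a childless node.
def leaf[A](a: A): Rose[A] = Rose(a, Nil)

// Left identity: wrapping a value and binding equals applying f directly.
def leftIdentity[A, B](a: A, f: A => Rose[B]): Boolean =
  leaf(a).flatMap(f) == f(a)

// Right identity: binding with leaf leaves the tree unchanged.
def rightIdentity[A](t: Rose[A]): Boolean =
  t.flatMap(a => leaf(a)) == t

// Associativity: binding f then g equals binding the composed function.
def associativity[A, B, C](t: Rose[A], f: A => Rose[B], g: B => Rose[C]): Boolean =
  t.flatMap(f).flatMap(g) == t.flatMap(a => f(a).flatMap(g))
```

For example, with t = Rose(1, List(leaf(2))), f = (n: Int) => Rose(n, List(leaf(n * 10))) and g = (n: Int) => Rose(n + 1, List(leaf(-n))), all three checks return true.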
Upvotes: 3
Reputation: 16412
This is not perfect; it just serves as an example. All you need to do is traverse the tree depth-first or breadth-first and collect the results, pretty much the same as flatten for lists.
1) Define a tree structure (I know, I know, it's not the best way to do it :)):
scala> case class Node[T](value: T, left: Option[Node[T]] = None,
| right: Option[Node[T]] = None)
defined class Node
2) Create a little tree:
scala> val tree = Node(13,
| Some(Node(8,
| Some(Node(1)), Some(Node(11)))),
| Some(Node(17,
| Some(Node(15)), Some(Node(25))))
| )
tree: Node[Int] = Node(13,Some(Node(8,Some(Node(1,None,None)),Some(Node(11,None,None)))),Some(Node(17,Some(Node(15,None,None)),Some(Node(25,None,None)))))
3) Implement a function that can traverse a tree:
scala> def observe[T](node: Node[T], f: Node[T] => Unit): Unit = {
| f(node)
| node.left foreach { observe(_, f) }
| node.right foreach { observe(_, f) }
| }
observe: [T](node: Node[T], f: Node[T] => Unit)Unit
4) Use it to define a function that prints all values:
scala> def printall = observe(tree, (n: Node[_]) => println(n.value))
printall: Unit
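The same observe traversal can collect values instead of printing them, which already gives a flatten. This sketch repeats the Node class and observe from steps 1 and 3 so it stands alone; collectValues is a hypothetical name, and the mutable ListBuffer never escapes the function:

```scala
import scala.collection.mutable.ListBuffer

case class Node[T](value: T, left: Option[Node[T]] = None, right: Option[Node[T]] = None)

// observe from step 3: pre-order visit, calling f on every node.
def observe[T](node: Node[T], f: Node[T] => Unit): Unit = {
  f(node)
  node.left foreach { observe(_, f) }
  node.right foreach { observe(_, f) }
}

// Collect instead of print: append each visited value to a local buffer.
def collectValues[T](root: Node[T]): List[T] = {
  val buf = ListBuffer.empty[T]
  observe(root, (n: Node[T]) => buf += n.value)
  buf.toList
}
```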
5) Finally, define the flatten function:
scala> def flatten[T](node: Node[T]): List[T] = {
| def flatten[T](node: Option[Node[T]]): List[T] =
| node match {
| case Some(n) =>
| n.value :: flatten(n.left) ::: flatten(n.right)
| case None => Nil
| }
|
| flatten(Some(node))
| }
flatten: [T](node: Node[T])List[T]
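The flatten above concatenates with :::, which copies its left operand, so it can degrade to quadratic time on deep, left-leaning trees. A sketch of an alternative that threads an accumulator instead (flattenAcc is a hypothetical name; Node is repeated from step 1 so the block compiles on its own). It produces the same pre-order result:

```scala
case class Node[T](value: T, left: Option[Node[T]] = None, right: Option[Node[T]] = None)

// Pre-order flatten without :::. The right subtree is flattened onto the
// accumulator first, then the left subtree onto that, then the node's value
// is prepended, so each value is consed exactly once.
// Recursion depth is proportional to the tree height.
def flattenAcc[T](node: Node[T]): List[T] = {
  def go(opt: Option[Node[T]], acc: List[T]): List[T] = opt match {
    case None    => acc
    case Some(n) => n.value :: go(n.left, go(n.right, acc))
  }
  go(Some(node), Nil)
}
```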
6) Let's test. First, print all elements:
scala> printall
13
8
1
11
17
15
25
7) Run flatten:
scala> flatten(tree)
res1: List[Int] = List(13, 8, 1, 11, 17, 15, 25)
It's essentially a general-purpose tree traversal algorithm. I made it return Ts instead of Nodes; change that as you like.
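The steps above cover the depth-first case. For the breadth-first (level-order) traversal mentioned at the start, a sketch using an immutable Queue; flattenBreadthFirst is a hypothetical name and Node is repeated from step 1:

```scala
import scala.collection.immutable.Queue

case class Node[T](value: T, left: Option[Node[T]] = None, right: Option[Node[T]] = None)

// Level-order flatten: dequeue a node, emit its value, enqueue its children
// left to right. Values are prepended to acc and reversed once at the end.
def flattenBreadthFirst[T](root: Node[T]): List[T] = {
  @annotation.tailrec
  def loop(queue: Queue[Node[T]], acc: List[T]): List[T] =
    queue.dequeueOption match {
      case None => acc.reverse
      case Some((n, rest)) =>
        loop(rest ++ n.left ++ n.right, n.value :: acc)
    }
  loop(Queue(root), Nil)
}
```

On the tree from step 2 this yields List(13, 8, 17, 1, 11, 15, 25), level by level, compared with the depth-first List(13, 8, 1, 11, 17, 15, 25).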
Upvotes: 6