tamasgal

Reputation: 26329

Flattening a tree with parents/children and return all nodes

It's probably too late, but I can't sleep until it's solved:

I've got a tree with some parents, which have children, which in turn have children, and so on.

Now I need a function to get all nodes from the tree.

This is what currently works, but only with one level of depth:

def nodes_from_tree(tree, parent):
    r = []
    if len(tree.get_children(parent)) == 0:
        return parent
    for child in tree.get_children(parent):
        r.append(nodes_from_tree(tree, child))
    return r

Then I tried to pass r through so that it remembers the children, but I'm calling the function more than once and r cumulatively stores all nodes, even though I'm setting it to r=[]:

def nodes_from_tree(tree, parent, r=[]):
    r = []
    if len(tree.get_children(parent)) == 0:
        return parent
    for child in tree.get_children(parent):
        r.append(nodes_from_tree(tree, child, r))
    return r

Edit: This is the tree structure:

parent1    parent2    parent3
   |          |          |
   |          |          |
 child        |          |
              |          |
      +--------------+   |
      |       |      |   |
    child   child  child |
      |                  |
  +---+---+              |
child   child        +---+---+
                     |       |
                   child     |
                             |
                       +-----+-----+-----+
                       |     |     |     |
                     child child child child

Available methods:

tree.get_parents()       # returns the nodes of the very top level
tree.get_children(node)  # returns the children of parent or child

Upvotes: 2

Views: 8799

Answers (2)

abarnert

Reputation: 366083

I think your problem is just that you're accumulating things incorrectly.

First, if you hit an intermediate node, each child should return a list, but you're appending that list instead of extending it. So, instead of [1, 2, 3, 4] you're going to get something like [[1, 2], [3, 4]]—in other words, you're just transforming it into a list-of-list tree, not a flat list. Change this to extend.
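To make the append/extend difference concrete, here is a minimal sketch using plain lists; the names `nested` and `flat` are only illustrative and do not come from the question:

```python
# append adds its argument as ONE element, so lists end up nested.
nested = []
nested.append([1, 2])
nested.append([3, 4])

# extend adds each element of its argument, so the result stays flat.
flat = []
flat.extend([1, 2])
flat.extend([3, 4])

print(nested)  # [[1, 2], [3, 4]]
print(flat)    # [1, 2, 3, 4]
```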

Second, if you hit a leaf node, you're not returning a list at all, just parent. Change this to return [parent].

Third, if you hit an intermediate node, you don't include parent anywhere, so you're only going to end up with the leaves. But you wanted all the nodes. So change the r = [] to r = [parent].

And with that last change, you don't need the if block at all. If there are no children, the loop will happen 0 times, and you'll end up returning [parent] as-is, exactly as you wanted to.

So:

def nodes_from_tree(tree, parent, r=[]):
    r = [parent]
    for child in tree.get_children(parent):
        r.extend(nodes_from_tree(tree, child, r))
    return r

Meanwhile, while this version will work, it's still confused. You're mixing up two different styles of recursion. Passing an accumulator down the chain and adding to it on the way down is one way to do it; returning values up the chain and accumulating results on the way up is the other. You're doing half of each.

As it turns out, the way you're doing the upstream recursion is making the downstream recursion have no effect at all. While you do pass an r down to each child, you never modify it, or even use it; you just create a new r list and return that.

The easiest way to fix that is to just remove the accumulator argument:

def nodes_from_tree(tree, parent):
    r = [parent]
    for child in tree.get_children(parent):
        r.extend(nodes_from_tree(tree, child))
    return r
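To get every node in the whole structure, you would presumably call this once per top-level parent and combine the results. Here is a rough sketch; `SimpleTree` and `all_nodes` are hypothetical names invented for this example, and the class only mimics the two methods listed in the question, so your real tree class will differ:

```python
class SimpleTree:
    """Hypothetical stand-in for the asker's tree (assumption, not the real class)."""

    def __init__(self, parents, children):
        self._parents = list(parents)      # top-level nodes
        self._children = dict(children)    # node -> list of child nodes

    def get_parents(self):
        return list(self._parents)

    def get_children(self, node):
        # Leaves have no entry, so they get an empty list.
        return list(self._children.get(node, []))


def nodes_from_tree(tree, parent):
    # Same function as above: the node itself, then all its descendants.
    r = [parent]
    for child in tree.get_children(parent):
        r.extend(nodes_from_tree(tree, child))
    return r


def all_nodes(tree):
    # Flatten every top-level subtree into one list.
    nodes = []
    for parent in tree.get_parents():
        nodes.extend(nodes_from_tree(tree, parent))
    return nodes


tree = SimpleTree(
    parents=["p1", "p2"],
    children={"p1": ["a"], "p2": ["b", "c"], "b": ["d", "e"]},
)
flat = all_nodes(tree)
print(flat)  # ['p1', 'a', 'p2', 'b', 'd', 'e', 'c']
```

Because each call builds a fresh list, calling `all_nodes` repeatedly does not accumulate nodes from earlier calls.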

(It's worth noting that branching recursion can only be tail-call-optimized if you do it in downstream accumulator style instead of upstream gathering style. But that doesn't really matter in Python, because Python doesn't do tail call optimization. So, write whichever one makes more sense to you.)
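For comparison, the downstream accumulator style might look roughly like the sketch below. The name `collect_nodes`, the `get_children` callable, and the dict-based tree are assumptions made for this example, not part of the question's API. Note the `None` default: a literal `[]` default would be created once and shared between calls.

```python
def collect_nodes(get_children, node, acc=None):
    # Create a fresh accumulator only on the outermost call.
    if acc is None:
        acc = []
    acc.append(node)
    for child in get_children(node):
        # Pass the SAME list down; children add to it directly.
        collect_nodes(get_children, child, acc)
    return acc


# Hypothetical tree stored as a dict of node -> children.
children = {"root": ["a", "b"], "a": ["c"]}


def get_children(node):
    return children.get(node, [])


result = collect_nodes(get_children, "root")
print(result)  # ['root', 'a', 'c', 'b']
```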

Upvotes: 6

HennyH

Reputation: 7944

If I understand your question, you want to make a flat list containing all the values in a tree. In that case, for a tree represented by nested tuples, the following would work:

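def nodes_from_tree(tree, nodes=None):
    # Default to None: a mutable default list would be shared across calls.
    if nodes is None:
        nodes = []
    if isinstance(tree, tuple):
        for child in tree:
            nodes_from_tree(child, nodes=nodes)
    else:
        nodes.append(tree)
    return nodes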

mynodes = []
tree = (('Root',
        ('Parent',(
            ('Child1',),
            ('Child2',)
            )
        ),
        ('Parent2',(
            ('child1',(
                ('childchild1','childchild2')
            )),
            ('child2',),
            ('child3',)
        )),
        ('Parent3',(
            ('child1',),
            ('child2',(
                ('childchild1',),
                ('childchild2',),
                ('childchild3',),
                ('childchild4',)
            ))
        ))
    ))
nodes_from_tree(tree,nodes=mynodes)
print(mynodes)

Produces

['Root', 'Parent', 'Child1', 'Child2', 'Parent2', 'child1', 'childchild1', 'childchild2',
 'child2', 'child3', 'Parent3', 'child1', 'child2', 'childchild1', 'childchild2', 'childchild3', 'childchild4']

Upvotes: 3
