Tom Carrick

Reputation: 6616

Tree traversal and getting neighbouring child nodes in Python

I'm trying to traverse a tree, and get certain subtrees into a particular data structure. I think an example is the best way to explain it:

[Image: a tree rooted at a, with children b and c; c has children d, e and f; f has children g and h]

For this tree, I want the root node and its children. Then any children that have their own children should be traversed in the same way, and so on. So for the above tree, we would end up with a data structure such as:

[
    (a, [b, c]),
    (c, [d, e, f]),
    (f, [g, h]),
]

I have some code so far to produce this, but it seems to stop too early:

from spacy.en import English


def _subtrees(sent, root=None, subtrees=[]):
    if not root:
        root = sent.root

    children = list(root.children)
    if not children:
        return subtrees

    subtrees.append((root, [child for child in children]))
    for child in children:
        return _subtrees(sent, child, subtrees)


nlp = English()
doc = nlp('they showed us an example')
print(_subtrees(list(doc.sents)[0]))

Note that this code won't produce the same tree as in the image. I feel like a generator would be better suited here also, but my generator-fu is even worse than my recursion-fu.

Upvotes: 2

Views: 2455

Answers (4)

Zach Rosen

Reputation: 1

I can't quite comment yet, but if you modify the response by @syllogism_ like so, it'll omit all nodes that don't have any children:

[(word, list(word.children)) for word in s if bool(list(word.children))]

Upvotes: 0

syllogism_

Reputation: 4297

Assuming you want this for spaCy specifically, why not just:

[(word, list(word.children)) for word in sent]

The Doc object lets you iterate over all nodes in order, so you don't need to walk the tree recursively here: just iterate.
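
To see why a flat pass is enough, here is a minimal sketch that does not use spaCy at all. It assumes a hypothetical sentence stored as (word, head index) pairs; the head indices are made up for illustration and are not actual parser output:

```python
from collections import defaultdict

# Hypothetical flat "sentence": (word, index of its head); the root has head None.
# This mimics iterating tokens in sentence order rather than walking the tree.
tokens = [("they", 1), ("showed", None), ("us", 1), ("an", 4), ("example", 1)]

# Collect each token's children from the head pointers.
children = defaultdict(list)
for index, (word, head) in enumerate(tokens):
    if head is not None:
        children[head].append(word)

# One pass over the tokens, keeping only those with at least one child.
pairs = [(word, children[i]) for i, (word, _) in enumerate(tokens) if children[i]]
print(pairs)  # [('showed', ['they', 'us', 'example']), ('example', ['an'])]
```

Note that the pairs come out in sentence order, not in tree order, so the root is not necessarily first.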

Upvotes: 0

alexis

Reputation: 50220

Let's first sketch the recursive algorithm:

  • Given a tree node, return:

    1. A tuple of the node with its children
    2. The subtrees of each child.

That's all it takes, so let's convert it to pseudocode, ehm, Python:

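def subtrees(node):
    # Materialize the children once: if node.children is a generator
    # (as it is for spaCy tokens), it is always truthy, even when empty.
    children = list(node.children)
    if not children:
        return []

    result = [(node.dep, children)]
    for child in children:
        # Recurse into every child and collect all of their subtrees.
        result.extend(subtrees(child))

    return result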

The root is just a node, so it shouldn't need special treatment. But please fix the member references if I misunderstood the data structure.
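
The question mentions that a generator might suit this better. Here is a minimal sketch of the same recursion written as a generator. The Node class and its label attribute are hypothetical stand-ins for spaCy tokens, used only so the example runs on its own:

```python
class Node:
    """Hypothetical stand-in for a spaCy token: a label plus child nodes."""

    def __init__(self, label, children=()):
        self.label = label
        self.children = list(children)

    def __repr__(self):
        return self.label


def iter_subtrees(node):
    """Yield (node, children) for every node that has children, depth-first."""
    children = list(node.children)  # works for lists and generators alike
    if not children:
        return
    yield node, children
    for child in children:
        # Recurse into every child; a `return` here would stop after the first.
        yield from iter_subtrees(child)


# The tree from the question: a -> b, c; c -> d, e, f; f -> g, h
g, h = Node("g"), Node("h")
f = Node("f", [g, h])
d, e = Node("d"), Node("e")
c = Node("c", [d, e, f])
b = Node("b")
a = Node("a", [b, c])

pairs = list(iter_subtrees(a))
print(pairs)  # [(a, [b, c]), (c, [d, e, f]), (f, [g, h])]
```

Because nothing is accumulated in a shared list, this also avoids the mutable default argument (subtrees=[]) used in the question's code, which would carry results over between calls.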

Upvotes: 1

David Perez

Reputation: 488

from collections import deque


def _subtrees(root):
    # Breadth-first traversal: visit nodes level by level using a queue.
    subtrees = []
    queue = deque([root])
    while queue:
        node = queue.popleft()  # take the next node from the front
        children = list(node.children)
        if children:
            queue.extend(children)  # visit the children later
            subtrees.append((node.dep, [child.dep for child in children]))

    return subtrees

Upvotes: 0
