Batong

Reputation: 59

Find grandparent node using nltk

I am using the Tree package from nltk with Python 2.7, and I want to extract every rule from a tree together with its grandparent node. I have the following tree:

t = Tree('S', [Tree('NP', [Tree('D', ['the']), Tree('N', ['dog'])]), Tree('VP', [Tree('V', ['chased']), Tree('NP', [Tree('D', ['the']), Tree('N', ['cat'])])])])

and the productions

   t.productions()
   [S -> NP VP, NP -> D N, D -> 'the', N -> 'dog', VP -> V NP, V -> 'chased', NP -> D N, D -> 'the', N -> 'cat']
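To make the structure of that list explicit: productions() emits one rule per internal node, visiting the nodes in preorder (a node before its children, left to right), which is why the two NP -> D N rules appear where they do. Below is a plain-Python sketch of that walk, not nltk's own code. It assumes a hypothetical representation in which a node is a (label, children) tuple and a leaf (word) is a plain string.

```python
# Hypothetical plain-Python stand-in for an nltk Tree: a node is a
# (label, children) tuple and a leaf (word) is a plain string.
TREE = ('S', [('NP', [('D', ['the']), ('N', ['dog'])]),
              ('VP', [('V', ['chased']),
                      ('NP', [('D', ['the']), ('N', ['cat'])])])])


def productions(node):
    """Return one 'LHS -> RHS' string per internal node, in preorder."""
    label, children = node
    # Subtrees contribute their label; leaves are quoted with repr()
    rhs = ' '.join(c[0] if isinstance(c, tuple) else repr(c) for c in children)
    rules = [label + ' -> ' + rhs]
    for child in children:
        if isinstance(child, tuple):
            rules += productions(child)   # recurse left to right
    return rules


print(productions(TREE))
```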

for the tree:

               S               
       ________|_____           
      |              VP        
      |         _____|___       
      NP       |         NP    
   ___|___     |      ___|___   
  D       N    V     D       N 
  |       |    |     |       |  
 the     dog chased the     cat

What I want is something in the form:

[S -> NP VP, S ^ NP -> D N, NP ^ D -> 'the', NP ^ N -> 'dog'.......]
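In other words, each rule's left-hand side is prefixed with the label of its parent, and the root rule stays unannotated. A minimal sketch of that transformation, again on a hypothetical (label, children) tuple representation rather than nltk's Tree: walk the tree with an explicit stack of (node, parent label) pairs, so every node already knows its parent's label when its rule is built.

```python
# Hypothetical (label, children) tuples; leaves are plain strings.
TREE = ('S', [('NP', [('D', ['the']), ('N', ['dog'])]),
              ('VP', [('V', ['chased']),
                      ('NP', [('D', ['the']), ('N', ['cat'])])])])


def annotated_productions(root):
    """Preorder walk that prefixes each rule's LHS with its parent's label.

    The root has no parent, so its rule is left unannotated."""
    rules = []
    stack = [(root, None)]              # (node, label of its parent)
    while stack:
        node, parent = stack.pop()
        label, children = node
        lhs = label if parent is None else parent + ' ^ ' + label
        rhs = ' '.join(c[0] if isinstance(c, tuple) else repr(c) for c in children)
        rules.append(lhs + ' -> ' + rhs)
        # Push subtrees in reverse so the leftmost child is popped first
        for child in reversed(children):
            if isinstance(child, tuple):
                stack.append((child, label))
    return rules


print(annotated_productions(TREE))
```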

I've looked at the ParentedTree class, but I don't see how to use it to solve my problem.
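The idea behind a parented tree is that every subtree keeps a pointer to its parent, so the annotation can be read off each node instead of being passed down during the walk. The sketch below illustrates that with a small hypothetical Node class; it is not nltk's API. With nltk itself, the corresponding pieces would be ParentedTree.convert(t) to build a parented copy of the tree, subtrees() to visit every node, and parent() and label() on each subtree.

```python
class Node(object):
    """Hypothetical minimal tree node that records its parent, the same
    idea nltk's ParentedTree is built on (this is NOT nltk's class)."""

    def __init__(self, label, children):
        self.label = label
        self.children = children
        self.parent = None
        for child in children:
            if isinstance(child, Node):
                child.parent = self        # each subtree points back up

    def subtrees(self):
        """Yield this node and all descendant nodes in preorder."""
        yield self
        for child in self.children:
            if isinstance(child, Node):
                for sub in child.subtrees():
                    yield sub


def annotated_rule(node):
    """Build 'parent ^ label -> rhs' by looking up the stored parent."""
    lhs = node.label
    if node.parent is not None:            # the root has no parent
        lhs = node.parent.label + ' ^ ' + lhs
    rhs = ' '.join(c.label if isinstance(c, Node) else repr(c)
                   for c in node.children)
    return lhs + ' -> ' + rhs


t = Node('S', [Node('NP', [Node('D', ['the']), Node('N', ['dog'])]),
               Node('VP', [Node('V', ['chased']),
                           Node('NP', [Node('D', ['the']), Node('N', ['cat'])])])])

rules = [annotated_rule(n) for n in t.subtrees()]
print(rules)
```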

Upvotes: 3

Views: 203

Answers (1)

RAVI

Reputation: 3153

You need to write your own version of the productions method that passes each node's label down to its children.

Code:

from nltk.tree import Tree
from nltk.grammar import Production, Nonterminal

try:                        # Python 2: accept both str and unicode labels
    string_types = basestring
except NameError:           # Python 3
    string_types = str


def child_names(tree):
    # Same job as nltk's private _child_names helper: subtrees become
    # Nonterminals, leaves (the words) stay as plain strings
    return [Nonterminal(child.label()) if isinstance(child, Tree) else child
            for child in tree]


def productions(t, parent=None):
    # Like Tree.productions(), but each left-hand side becomes
    # parent + " ^ " + label; the root (parent=None) is left unannotated
    if not isinstance(t.label(), string_types):
        raise TypeError('Productions can only be generated from trees having node labels that are strings')

    lhs = t.label() if parent is None else parent + " ^ " + t.label()
    prods = [Production(Nonterminal(lhs), child_names(t))]
    for child in t:
        if isinstance(child, Tree):
            prods += productions(child, t.label())
    return prods


t = Tree('S', [Tree('NP', [Tree('D', ['the']), Tree('N', ['dog'])]), Tree('VP', [Tree('V', ['chased']), Tree('NP', [Tree('D', ['the']), Tree('N', ['cat'])])])])

# To skip the parent of 'S' (the root has none)
prods = productions(t)

# To add 'Start' as the parent of 'S' instead
# prods = productions(t, "Start")

print(prods)

Output:

[S -> NP VP, S ^ NP -> D N, NP ^ D -> 'the', 
    NP ^ N -> 'dog', S ^ VP -> V NP, VP ^ V -> 'chased', 
    VP ^ NP -> D N, NP ^ D -> 'the', NP ^ N -> 'cat']

For more information, see the source of the productions method in nltk.tree.

Upvotes: 1
