azmo

Reputation: 186

Raising an Exception from recursive generator function

I am quite new to Python and am trying to get my head around generators, specifically the yield statement. I'm playing around by writing a classic Tree class that stores keys and data.

#!/usr/bin/env python3

class Tree:
    def __init__(self, key, data):
        "Create a new Tree object with empty L & R subtrees."
        self.key = key
        # store passed data
        self.data = data
        self.left = self.right = None

    def insert(self, key, data):
        "Insert a new element and data into the tree in the correct position."
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")

    def walk(self):
        "Generate the keys and data from the tree in sorted order."
        if self.left:
            for n in self.left.walk():
                yield n
        # yield the data along with the key
        yield self.key, self.data
        if self.right:
            for n in self.right.walk():
                yield n
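Since the shebang targets Python 3, the two `for n in ...: yield n` loops in walk() can also be written with `yield from` (Python 3.3+), which delegates directly to the sub-generator. Here is a minimal sketch of the same class written that way, with a small usage example. The sample keys and data are made up for illustration.

```python
class Tree:
    """Binary search tree node holding a key and its associated data."""

    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.left = self.right = None

    def insert(self, key, data):
        """Insert key/data at the correct position; reject duplicate keys."""
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")

    def walk(self):
        """Yield (key, data) pairs in sorted (in-order) key order."""
        if self.left:
            yield from self.left.walk()   # delegate to the left subtree
        yield self.key, self.data
        if self.right:
            yield from self.right.walk()  # delegate to the right subtree


tree = Tree(5, "five")
for k, d in [(3, "three"), (8, "eight"), (1, "one")]:
    tree.insert(k, d)
print(list(tree.walk()))
# [(1, 'one'), (3, 'three'), (5, 'five'), (8, 'eight')]
```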

This works nicely so far. Now I am trying to implement a find() method that walks the tree and returns the data for a given key.

def find(self, key):
    if self.left:
        for n in self.left.find(key):
            yield n

    if self.right:
        for n in self.right.find(key):
            yield n

    if self.key == key:
        yield self.data

The function works, but I want it to raise a KeyError if the key is not found anywhere in the tree. I've tried to wrap my head around it, but I don't see a (simple) way to do this with yield statements. Specifically, I can't come up with a way to know when the tree has been completely walked and the key still hasn't been found.

Thanks in advance!

Upvotes: 2

Views: 328

Answers (2)

georg

Reputation: 214949

I notice that your find doesn't take advantage of the fact that the tree is sorted. How about this implementation, which returns the data directly instead of yielding it:

def find(self, key):
    if key == self.key:
        return self.data
    if key < self.key and self.left:
        return self.left.find(key)
    if key > self.key and self.right:
        return self.right.find(key)
    raise KeyError("No such thing")
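Because each step follows only one branch, this lookup visits at most one node per level instead of the whole tree. The same search can also be written as a loop, which sidesteps Python's recursion limit on a badly unbalanced tree (for example, one built by inserting keys in sorted order). Below is a sketch of that iterative variant, assuming the question's insert(); it is a swapped-in alternative, not the answer's code.

```python
class Tree:
    """Minimal BST node: insert() follows the question, find() is iterative."""

    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.left = self.right = None

    def insert(self, key, data):
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")

    def find(self, key):
        """Return the data stored under key, or raise KeyError if absent."""
        node = self
        while node is not None:
            if key == node.key:
                return node.data
            # BST ordering: the key can only be in one of the two subtrees.
            node = node.left if key < node.key else node.right
        raise KeyError(key)


tree = Tree(5, "five")
tree.insert(3, "three")
tree.insert(8, "eight")
print(tree.find(3))  # three
try:
    tree.find(4)
except KeyError as exc:
    print("not found:", exc)  # not found: 4
```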

Upvotes: 2

Matt Anderson

Reputation: 19769

Rename your current find() to _find(), then add this wrapper:

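def find(self, key):
    gen = self._find(key)
    try:
        first = next(gen)  # Python 3: use next(gen), not gen.next()
    except StopIteration:
        # _find() finished without yielding anything: the key is not in the tree
        raise KeyError(key) from None
    yield first
    for item in gen:
        yield item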
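Because this find() is itself a generator, calling tree.find(key) does not raise anything; the KeyError only surfaces when the result is first iterated. The sketch below shows that behavior and adds a hypothetical non-generator helper, find_first() (the name is an assumption, not part of the answer), which raises immediately: if the `for` loop over _find() runs to completion without returning, no match was found.

```python
class Tree:
    """Minimal BST node; _find() is the question's original find()."""

    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.left = self.right = None

    def insert(self, key, data):
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")

    def _find(self, key):
        # Visits every node and yields the data of any node whose key matches.
        if self.left:
            yield from self.left._find(key)
        if self.right:
            yield from self.right._find(key)
        if self.key == key:
            yield self.data

    def find(self, key):
        # Generator wrapper from the answer: no match at all -> KeyError.
        gen = self._find(key)
        try:
            first = next(gen)
        except StopIteration:
            raise KeyError(key) from None
        yield first
        yield from gen

    def find_first(self, key):
        # Hypothetical helper: a plain function, so KeyError is raised at call time.
        for data in self._find(key):
            return data
        raise KeyError(key)


tree = Tree(5, "five")
tree.insert(3, "three")

lazy = tree.find(4)  # no exception yet: this only creates a generator object
try:
    next(lazy)
except KeyError:
    print("KeyError raised on the first next()")

print(tree.find_first(3))  # three
```

Since insert() rejects duplicate keys, _find() can yield at most one value, so a plain function like find_first() (or georg's version) is arguably a better fit here than a generator.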

Upvotes: 1
