Reputation: 186
I am quite new to Python and am trying to get my head around generators, specifically the yield statement. I have been playing around by writing a classic Tree class that stores keys and data.
#!/usr/bin/env python3
class Tree:
    def __init__(self, key, data):
        "Create a new Tree object with empty L & R subtrees."
        self.key = key
        # store passed data
        self.data = data
        self.left = self.right = None
    def insert(self, key, data):
        "Insert a new element and data into the tree in the correct position."
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")
    def walk(self):
        "Generate the keys and data from the tree in sorted order."
        if self.left:
            for n in self.left.walk():
                yield n
        # change output to include data
        yield self.key, self.data
        if self.right:
            for n in self.right.walk():
                yield n
This works quite nicely so far. Now I am trying to implement a find() function that walks the tree and returns the data stored under a given key.
    def find(self, key):
        if self.left:
            for n in self.left.find(key):
                yield n
        if self.right:
            for n in self.right.find(key):
                yield n
        if self.key == key:
            yield self.data
The function works, but I want to raise a KeyError if the key is nowhere to be found in the tree. I tried to wrap my head around it, but I don't see a (simple) way to do this when using yield statements. Specifically, I can't come up with a way to know when the tree has been completely walked and the key still hasn't been found.
Thanks in advance!
Upvotes: 2
Views: 328
Reputation: 214949
I notice that find doesn't use the fact that the tree is sorted. How about this implementation:
    def find(self, key):
        if key == self.key:
            return self.data
        if key < self.key and self.left:
            return self.left.find(key)
        if key > self.key and self.right:
            return self.right.find(key)
        raise KeyError("No such thing")
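To see how this behaves, here is a minimal self-contained sketch. The Tree class is trimmed down from the question, and the sample keys and values are made up for illustration. Note that this find() is a plain method rather than a generator, so the KeyError is raised as soon as the call is made.

```python
class Tree:
    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.left = self.right = None

    def insert(self, key, data):
        if key < self.key:
            if self.left:
                self.left.insert(key, data)
            else:
                self.left = Tree(key, data)
        elif key > self.key:
            if self.right:
                self.right.insert(key, data)
            else:
                self.right = Tree(key, data)
        else:
            raise ValueError("Attempt to insert duplicate value")

    def find(self, key):
        # Descend only into the subtree that can contain the key.
        if key == self.key:
            return self.data
        if key < self.key and self.left:
            return self.left.find(key)
        if key > self.key and self.right:
            return self.right.find(key)
        # Nowhere left to look: the key is not in the tree.
        raise KeyError(key)


# Hypothetical sample data, not from the question.
root = Tree(5, "five")
for k, v in [(2, "two"), (8, "eight"), (1, "one")]:
    root.insert(k, v)

found = root.find(8)

try:
    root.find(7)
    missing_raised = False
except KeyError:
    missing_raised = True
```

Because each step discards one subtree, the lookup only follows a single path from the root instead of visiting every node.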
Upvotes: 2
Reputation: 19769
Rename your current find() as _find() (and have its recursive calls use _find() as well), then:
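    def find(self, key):
        gen = self._find(key)
        try:
            # next(gen) works on Python 3; gen.next() is Python 2 only
            yield next(gen)
        except StopIteration:
            # _find() produced nothing at all, so the key is absent
            raise KeyError(key)
        for item in gen:
            yield item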
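A rough runnable sketch of this approach, assuming Python 3 (hence next(gen) and yield from). The tree is wired up by hand with made-up sample data to keep the example short. One thing to be aware of: find() is still a generator, so the KeyError only appears once the caller starts iterating, not when find() is called.

```python
class Tree:
    def __init__(self, key, data):
        self.key = key
        self.data = data
        self.left = self.right = None

    def _find(self, key):
        # The question's generator-based search, renamed; it visits every node.
        if self.left:
            yield from self.left._find(key)
        if self.right:
            yield from self.right._find(key)
        if self.key == key:
            yield self.data

    def find(self, key):
        gen = self._find(key)
        try:
            first = next(gen)
        except StopIteration:
            # The inner generator finished without yielding: key not found.
            raise KeyError(key) from None
        yield first
        yield from gen


# Hypothetical sample data, linked manually instead of via insert().
root = Tree(5, "five")
root.left = Tree(2, "two")
root.right = Tree(8, "eight")

hits = list(root.find(8))

lazy = root.find(7)  # no exception yet, the generator body has not run
try:
    next(lazy)
    raised = False
except KeyError:
    raised = True
```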
Upvotes: 1