Reputation: 13
I'm trying to create a Uniform Cost Search algorithm. but I'm having a problem in storing nodes in the priorityqueue.
It works well till node D as shown in the output provided, and I'm not sure why. Any help will be appreciated.
The error says it can't compare nodes but I'm adding them as tuples so it can use the distacne for comparison
class GraphEdge(object):
def __init__(self, destinationNode, distance):
self.node = destinationNode
self.distance = distance
class GraphNode(object):
def __init__(self, val):
self.value = val
self.edges = []
def add_child(self, node, distance):
self.edges.append(GraphEdge(node, distance))
def remove_child(self, del_node):
if del_node in self.edges:
self.edges.remove(del_node)
class Graph(object):
def __init__(self, node_list):
self.nodes = node_list
def add_edge(self, node1, node2, distance):
if node1 in self.nodes and node2 in self.nodes:
node1.add_child(node2, distance)
node2.add_child(node1, distance)
def remove_edge(self, node1, node2):
if node1 in self.nodes and node2 in self.nodes:
node1.remove_child(node2)
node2.remove_child(node1)
from queue import PriorityQueue
def build_path(root_node, goal_node):
path = [goal_node]
add_parent(root_node, goal_node, path)
return path
def add_parent(root_node, node, path):
parent = node.parent
path.append(parent)
if parent == root_node:
return
else:
add_parent(root_node, parent, path)
def ucs_search(root_node, goal_node):
visited = set()
queue = PriorityQueue()
queue.put((0, root_node))
visited_order = []
while not queue.empty():
current_node_priority, current_node = queue.get()
visited.add(current_node)
visited_order.append(current_node.value)
print("current_node:", current_node.value)
if current_node == goal_node:
print(visited_order)
return current_node, build_path(root_node, goal_node)
for edge in current_node.edges:
child = edge.node
if child not in visited:
child.parent = current_node
print("child:", child.value)
queue.put(((current_node_priority + edge.distance), child))
node_u = GraphNode('U')
node_d = GraphNode('D')
node_a = GraphNode('A')
node_c = GraphNode('C')
node_i = GraphNode('I')
node_t = GraphNode('T')
node_y = GraphNode('Y')
graph = Graph([node_u, node_d, node_a, node_c, node_i, node_t, node_y])
graph.add_edge(node_u, node_a, 4)
graph.add_edge(node_u, node_c, 6)
graph.add_edge(node_u, node_d, 3)
graph.add_edge(node_d, node_c, 4)
graph.add_edge(node_a, node_i, 7)
graph.add_edge(node_c, node_i, 4)
graph.add_edge(node_c, node_t, 5)
graph.add_edge(node_i, node_y, 4)
graph.add_edge(node_t, node_y, 5)
goal, sequence = ucs_search(node_a, node_y)
Output:
current_node: A
child: U
child: I
current_node: U
child: C
child: D
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-52-2d575db64232> in <module>
19 graph.add_edge(node_t, node_y, 5)
20
---> 21 goal, sequence = ucs_search(node_a, node_y)
<ipython-input-51-b26ec19983b6> in ucs_search(root_node, goal_node)
36 child.parent = current_node
37 print("child:", child.value)
---> 38 queue.put(((current_node_priority + edge.distance), child))
39
~\AppData\Local\Continuum\anaconda3\lib\queue.py in put(self, item, block, timeout)
147 raise Full
148 self.not_full.wait(remaining)
--> 149 self._put(item)
150 self.unfinished_tasks += 1
151 self.not_empty.notify()
~\AppData\Local\Continuum\anaconda3\lib\queue.py in _put(self, item)
231
232 def _put(self, item):
--> 233 heappush(self.queue, item)
234
235 def _get(self):
TypeError: '<' not supported between instances of 'GraphNode' and 'GraphNode'
Upvotes: 1
Views: 143
Reputation: 19223
If two tuples in the queue have the same distance, the priority queue needs to tiebreak based on the priority value of the corresponding GraphNode
s. Since the __lt__
function isn't defined for GraphNode
s, this will cause an error. (The __lt__
function defines how two GraphNode
s can be compared using the <
operator.)
To resolve, define the __lt__
function for the GraphNode
class. This is the function that Python calls when comparing two GraphNodes
:
class GraphNode(object):
def __init__(self, val):
self.value = val
self.edges = []
def add_child(self, node, distance):
self.edges.append(GraphEdge(node, distance))
def remove_child(self, del_node):
if del_node in self.edges:
self.edges.remove(del_node)
def __lt__(self, other):
return self.value < other.value
Upvotes: 0