mol
mol

Reputation: 63

What is wrong with my Dijkstra algorithm for undirected graph?

Undirected graph dijkstra algorithm. Given start node, return a table mapping the shortest path from A to each node and its value.

from heapq import heappush, heappop

def dijkstra(edges, start):

    graph = {}
    for (x, y, z) in edges:
        # A: [('B', 6)]
        graph[x] = graph.get(x, []) + [(y, z)]
        graph[y] = graph.get(y, []) + [(x, z)] # undirected graph
    
    
    table = {}
    for v in graph:
        table[v] = (float("inf"), None) # inf, no previous node
    table[start] = (0, None)
    stack = [(0, start)]
    
    visited = []
    while(stack != []):
        fill, node = heappop(stack)
        w = table[node][0]
        if(node in visited):
            continue
    
        visited.append(node)
        for (v, weight) in graph[node]:
            
            
            cur_weight, prev = table[v]
            if(cur_weight > weight + w):
                table[v] = (weight + w, node)
                
            cur_weight, prev = table[v]   
            heappush(stack, (cur_weight, v))
            
    return table           

edges = [['A', 'C', 1], ['C', 'E', 1], ['E', 'B', 1], ['A', 'B', 10]]
    
print(dijkstra(edges, 'A')) # outputs the correct table

the output is correct above for the table but for extemely large output like (n = 5000) it seems to fail and im unsure why?

Upvotes: 1

Views: 175

Answers (1)

mol
mol

Reputation: 63

Swapped stack to minheap to prevent test cases like what was mentioned in the comments.

from heapq import heappush, heappop

def dijkstra(edges, start):

    graph = {}
    for (x, y, z) in edges:
        # A: [('B', 6)]
        graph[x] = graph.get(x, []) + [(y, z)]
        graph[y] = graph.get(y, []) + [(x, z)] # undirected graph
    
    
    table = {}
    for v in graph:
        table[v] = (float("inf"), None) # inf, no previous node
    table[start] = (0, None)
    stack = [(0, start)]
    
    visited = set()
    while(stack != []):
        w, node = heappop(stack)
        if(node in visited):
            continue
    
        visited.add(node)
        for (v, weight) in graph[node]:
            
            
            cur_weight, prev = table[v]
            if(cur_weight > weight + w):
                table[v] = (weight + w, node)          
                heappush(stack, (weight + w, v))
            
    return table

Edit: Optimized it based on comments below

Upvotes: 1

Related Questions