FaCoffee
FaCoffee

Reputation: 7909

Networkx: all Spanning Trees and their associated total weight

Given a simple undirected grid network like this:

import networkx as nx
from pylab import *
import matplotlib.pyplot as plt
%pylab inline

# Build an N x N grid graph, relabel its (i, j) nodes as integers, and draw it.
ncols=3  # NOTE(review): not used anywhere in the code shown here
N=3 
G=nx.grid_2d_graph(N,N)
# Map each grid node (i, j) to the integer label i + (N-1-j)*N.
labels = dict( ((i,j), i + (N-1-j) * N ) for i, j in G.nodes() )
# The third positional argument is copy=False, so G is relabelled in place.
nx.relabel_nodes(G,labels,False)
inds=labels.keys()
vals=labels.values()
# Drawing position of each original (i, j) node: (N-1-j, N-1-i).
inds=[(N-j-1,N-i-1) for i,j in inds]
# pos2: integer label -> (x, y). keys() and values() of the same unmodified
# dict iterate in matching order, so zip pairs each label with its position.
pos2=dict(zip(vals,inds))
nx.draw_networkx(G, pos=pos2, with_labels=True, node_size = 200, node_color='orange',font_size=10)
plt.axis('off')
plt.title('grid')
plt.show()

And given that each edge has a weight corresponding to its length:

#Weights
from math import sqrt

# Euclidean length of every edge, computed from the drawing positions in pos2,
# rounded to 3 decimals and keyed by the (source, target) pair from G.edges().
weights = dict()
for source, target in G.edges():
    x1, y1 = pos2[source]
    x2, y2 = pos2[target]
    # Only `sqrt` is imported above; calling `math.sqrt` would raise NameError.
    weights[(source, target)] = round(sqrt((x2 - x1)**2 + (y2 - y1)**2), 3)

# Store each length as the edge's 'weight' attribute. G[u][v] is the edge's
# attribute dict, so set a key in it. Assigning to G[u][v] itself would try
# to replace that dict, which is a read-only view in networkx >= 2.0.
for source, target in G.edges():
    G[source][target]['weight'] = weights[(source, target)]

How can I compute all spanning trees of the grid, together with their associated total weights?

NB: this is a trivial case in which all weights are equal to 1.

Upvotes: 1

Views: 3611

Answers (1)

Paul Brodersen
Paul Brodersen

Reputation: 13031

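This took way longer than expected, but the following code finds all spanning trees in the general case. Getting the associated total weight should then be trivial, since you have access to the edge list of each tree. For example, if the edge lengths are stored in G as the `'weight'` edge attribute, the total weight of a tree `T` is `sum(G[u][v]['weight'] for u, v in T.edges())`.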

Don't use this on large graphs -- the number of spanning trees grows exponentially with graph size, and even this 3x3 toy example yields 192 spanning trees.

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

def _expand(G, explored_nodes, explored_edges):
    """
    Expand existing solution by a process akin to BFS.

    Arguments:
    ----------
    G: networkx.Graph() instance
        full graph

    explored_nodes: set of ints
        nodes visited

    explored_edges: set of 2-tuples
        edges visited

    Returns:
    --------
    solutions: list, where each entry in turns contains two sets corresponding to explored_nodes and explored_edges
        all possible expansions of explored_nodes and explored_edges

    """
    frontier_nodes = list()
    frontier_edges = list()
    for v in explored_nodes:
        for u in nx.neighbors(G,v):
            if not (u in explored_nodes):
                frontier_nodes.append(u)
                frontier_edges.append([(u,v), (v,u)])

    return zip([explored_nodes | frozenset([v]) for v in frontier_nodes], [explored_edges | frozenset(e) for e in frontier_edges])

def find_all_spanning_trees(G, root=0):
    """
    Find all spanning trees of a Graph.

    Partial trees are grown from ``root`` one node at a time with
    ``_expand``. After every step, duplicates (the same tree reached through
    a different node order) are discarded. Every spanning tree contains
    ``root``, so the result does not depend on which node of G is chosen.

    Arguments:
    ----------
    G: networkx.Graph() instance
        full graph

    root: node of G (default 0)
        node from which the partial trees are grown

    Returns:
    --------
    ST: list of networkx.Graph() instances
        list of all spanning trees. It is empty when G is disconnected.
        Every tree contains all nodes of G, including the lone node of a
        one-node graph.

    Raises:
    -------
    networkx.NetworkXError
        if G has nodes but ``root`` is not one of them

    """
    if G.number_of_nodes() == 0:
        # The null graph is its own (only) spanning tree.
        return [nx.Graph()]
    if root not in G:
        raise nx.NetworkXError("The root node %s is not in the graph." % (root,))

    # initialise solution: the tree containing only the root
    solutions = {(frozenset([root]), frozenset())}
    # a spanning tree has number_of_nodes - 1 edges: expand that many times,
    # keeping only unique expansions after each step
    for _ in range(G.number_of_nodes() - 1):
        solutions = {expansion
                     for nodes, edges in solutions
                     for expansion in _expand(G, nodes, edges)}

    trees = []
    for nodes, edges in solutions:
        # Add the nodes explicitly: nx.from_edgelist would drop the single
        # node of a one-node graph, whose spanning tree has no edges.
        tree = nx.Graph()
        tree.add_nodes_from(nodes)
        tree.add_edges_from(edges)
        trees.append(tree)
    return trees


if __name__ == "__main__":

    # Build the same labelled N x N grid as in the question.
    N = 3
    G = nx.grid_2d_graph(N, N)
    labels = {(i, j): i + (N - 1 - j) * N for i, j in G.nodes()}
    nx.relabel_nodes(G, labels, False)  # copy=False: relabel G in place
    # integer label -> (x, y) drawing position
    pos2 = {label: (N - j - 1, N - i - 1) for (i, j), label in labels.items()}

    fig, ax = plt.subplots(1, 1)
    nx.draw_networkx(G, pos=pos2, with_labels=True, node_size=200,
                     node_color='orange', font_size=10, ax=ax)
    ax.axis('off')
    ax.set_title('grid')

    ST = find_all_spanning_trees(G)
    # print as a function call so the script runs on Python 3 as well as 2
    print(len(ST))  # 192 for the 3x3 grid

    # One figure per spanning tree, drawn at the grid positions.
    for k, g in enumerate(ST, 1):
        fig, ax = plt.subplots(1, 1)
        nx.draw_networkx(g, pos=pos2, with_labels=True, node_size=200,
                         node_color='orange', font_size=10, ax=ax)
        ax.axis('off')
        ax.set_title('spanning tree %d of %d' % (k, len(ST)))
        plt.show()

Upvotes: 5

Related Questions