Tomas Novotny

Reputation: 8317

Algorithm for counting connected components of a graph in Python

I am trying to write a script that counts the connected components of a graph, but I can't get the right result. I have a simple graph with 6 nodes (vertices): nodes 1 and 2 are connected, and nodes 3 and 4 are connected (6 vertices; 1-2, 3-4, 5, 6), so the graph contains 4 connected components. I use the following script to count them, but I get the wrong result (2).

nodes = [[1, [2], False], [2, [1], False], [3, [4], False], [4, [3], False], [5, [], False], [6, [], False]]
# 6 nodes, every node has an id, list of connected nodes and boolean whether the node has already been visited    

componentsCount = 0

def mark_nodes( list_of_nodes):
    global componentsCount
    componentsCount = 0
    for node in list_of_nodes:
      node[2] = False
      mark_node_auxiliary( node)

def mark_node_auxiliary( node): 
    global componentsCount
    if not node[2] == True: 
      node[2] = True
      for neighbor in node[1]:
        nodes[neighbor - 1][2] = True
        mark_node_auxiliary( nodes[neighbor - 1])
    else:
      unmarkedNodes = []
      for neighbor in node[1]:
        if not nodes[neighbor - 1][2] == True:  # This condition is never met. WHY???
          unmarkedNodes.append( neighbor)
          componentsCount += 1   
      for unmarkedNode in unmarkedNodes:
        mark_node_auxiliary( nodes[unmarkedNode - 1])

def get_connected_components_number( graph):
    result = componentsCount
    mark_nodes( graph)
    for node in nodes:
      if len( node[1]) == 0:      # For every vertex without neighbor...  
        result += 1               # ... increment number of connected components by 1.
    return result

print get_connected_components_number( nodes)

Can anyone please help me find the mistake?

Upvotes: 1

Views: 10784

Answers (3)

Yilmaz

Reputation: 49291

I'm posting my answer here to record what I learned. I solved it with depth-first search.

Given an adjacency list, the graph looks like this:

[image of the example graph]

Depth-first search recursively visits all the nodes in the graph. This part is simple:

    count=0
    # I touch all the nodes and if dfs returns True, count+=1
    for node in graph:
        if dfs(node):
            count+=1

Now we write the logic inside dfs. If we start from node 0, we mark it as visited and then visit its neighbors recursively. Eventually we reach node 2. Once every node reachable from node 0 has been explored, that component is complete, so we return True.

    def dfs(node):
        if node in visited:
            return False
        visited.add(node)
        for neighbor in graph[node]:
            dfs(neighbor)
        # Reaching this point means everything reachable from node has been explored
        return True

We started from node 0, visited everything up to node 2, and returned True. Because of the loop `for node in graph:`, the next call starts from node 1, but since node 1 has already been visited, it returns False.
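To make the True/False pattern concrete, here is a minimal sketch that records what dfs returns for each start node. The graph below is a hypothetical example chosen for illustration (nodes 0, 1, 2 in one component, node 3 isolated), not necessarily the one in the picture.

```python
# Hypothetical example graph (an assumption made for illustration):
# 0 - 1 - 2 form one component, 3 is isolated.
graph = {0: [1], 1: [0, 2], 2: [1], 3: []}
visited = set()

def dfs(node):
    # An already-visited node belongs to a component counted earlier.
    if node in visited:
        return False
    visited.add(node)
    for neighbor in graph[node]:
        dfs(neighbor)
    # Everything reachable from node has now been explored.
    return True

# Call dfs once per start node, in the dict's iteration order.
results = {node: dfs(node) for node in graph}
count = sum(results.values())  # True counts as 1, False as 0

print(results)  # {0: True, 1: False, 2: False, 3: True}
print(count)    # 2
```

Only the first node reached in each component returns True, so the number of True results equals the number of components.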

Here is the full code:

class Solution:
    def count_connected(self,graph):
        visited=set()
        count=0
        def dfs(node):
            if node in visited:
                return False
            visited.add(node)
            for neighbor in graph[node]:
                dfs(neighbor)
            # Reaching this point means everything reachable from node has been explored
            return True
        # Start a DFS from every node; each call that returns True found a new component
        for node in graph:
            if dfs(node):
                count+=1
        return count
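The recursive version can hit Python's recursion limit on large components (for example, a long chain of nodes). Below is a sketch of an iterative variant that uses an explicit stack instead of recursion. It is not the code from this answer; the function name `count_connected_iterative` and the dict form of the question's graph are assumptions made for the example.

```python
def count_connected_iterative(graph):
    """Count connected components of an undirected graph given as a
    dict that maps each node to a list of its neighbors."""
    visited = set()
    count = 0
    for start in graph:
        if start in visited:
            continue
        # An unvisited start node means a component not seen before.
        count += 1
        visited.add(start)
        stack = [start]
        while stack:
            node = stack.pop()
            for neighbor in graph[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)
    return count

# The graph from the question (1-2, 3-4, 5, 6), rewritten as a dict.
question_graph = {1: [2], 2: [1], 3: [4], 4: [3], 5: [], 6: []}
print(count_connected_iterative(question_graph))  # 4
```

Both versions assume that every node appears as a key in the dict, including nodes without neighbors.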

Upvotes: 1

user97370

Reputation:

A disjoint-set data structure will really help you write clear code here; see Wikipedia.

The basic idea is that you associate a set with each node in your graph, and for each edge you merge the sets of its two endpoints. Two sets x and y are the same if x.find() == y.find().

Here's the most naive implementation (which has bad worst-case complexity), but there are a couple of optimisations of the DisjointSet class on the Wikipedia page above that make it efficient in a handful of extra lines of code. I omitted them for clarity.

nodes = [[1, [2]], [2, [1]], [3, [4]], [4, [3]], [5, []], [6, []]]

def count_components(nodes):
    sets = {}
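    # One set per node, keyed by node id
    for node in nodes:
        sets[node[0]] = DisjointSet()
    # Merge the sets of the two endpoints of every edge
    for node in nodes:
        for vtx in node[1]:
            sets[node[0]].union(sets[vtx])
    # The number of distinct roots is the number of components
    return len(set(x.find() for x in sets.values()))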

class DisjointSet(object):
    def __init__(self):
        self.parent = None

    def find(self):
        if self.parent is None: return self
        return self.parent.find()

    def union(self, other):
        them = other.find()
        us = self.find()
        if them != us:
            us.parent = them

print(count_components(nodes))
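For reference, the two standard optimisations for a disjoint set are path compression and union by rank, which I take to be the ones meant above. Here is a sketch of how they might be added. The names `OptimisedDisjointSet` and `count_components_optimised` are made up for this example, and unlike the naive class, a root here is its own parent rather than having parent None.

```python
class OptimisedDisjointSet(object):
    def __init__(self):
        self.parent = self  # a root points to itself
        self.rank = 0       # upper bound on the height of this tree

    def find(self):
        # First pass: walk up to the root.
        root = self
        while root.parent is not root:
            root = root.parent
        # Second pass (path compression): point every node on the
        # path directly at the root.
        node = self
        while node is not root:
            next_node = node.parent
            node.parent = root
            node = next_node
        return root

    def union(self, other):
        a, b = self.find(), other.find()
        if a is b:
            return
        # Union by rank: attach the shallower tree under the deeper one.
        if a.rank < b.rank:
            a, b = b, a
        b.parent = a
        if a.rank == b.rank:
            a.rank += 1

def count_components_optimised(nodes):
    sets = {node[0]: OptimisedDisjointSet() for node in nodes}
    for node in nodes:
        for vtx in node[1]:
            sets[node[0]].union(sets[vtx])
    return len(set(s.find() for s in sets.values()))

nodes = [[1, [2]], [2, [1]], [3, [4]], [4, [3]], [5, []], [6, []]]
print(count_components_optimised(nodes))  # 4
```

The find method is written with loops rather than recursion so that long parent chains cannot hit the recursion limit.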

Upvotes: 6

dln385

Reputation: 12090

Sometimes it's easier to write code than to read it.

Put this through some tests; I'm pretty sure it will always work as long as every connection is bidirectional (as in your example).

def recursivelyMark(nodeID, nodes):
    (connections, visited) = nodes[nodeID]
    if visited:
        return
    nodes[nodeID][1] = True
    for connectedNodeID in connections:
        recursivelyMark(connectedNodeID, nodes)

def main():
    nodes = [[[1], False], [[0], False], [[3], False], [[2], False], [[], False], [[], False]]
    componentsCount = 0
    for (nodeID, (connections, visited)) in enumerate(nodes):
        if not visited:
            componentsCount += 1
            recursivelyMark(nodeID, nodes)
    print(componentsCount)

if __name__ == '__main__':
    main()

Note that I removed the ID from the node information since its position in the array is its ID. Let me know if this program doesn't do what you need.
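If the input might list a connection in only one direction, the count can come out too high: with a single one-way link from node 1 to node 0, the loop reaches node 0 first and never follows the link back. One way to guard against that is to make the connections symmetric before counting. This is a sketch only; `make_bidirectional` is a hypothetical helper and not part of the program above.

```python
def make_bidirectional(nodes):
    """Return a copy of nodes (same [[connections, visited], ...] layout)
    in which every connection a -> b also has the reverse b -> a."""
    neighbor_sets = [set(connections) for connections, _ in nodes]
    for node_id, (connections, _) in enumerate(nodes):
        for other_id in connections:
            neighbor_sets[other_id].add(node_id)
    # Sort for a stable order and reset every visited flag to False.
    return [[sorted(neighbors), False] for neighbors in neighbor_sets]

# Each connection is listed on one side only.
one_way = [[[1], False], [[], False], [[3], False], [[], False], [[], False], [[], False]]
print(make_bidirectional(one_way))
# [[[1], False], [[0], False], [[3], False], [[2], False], [[], False], [[], False]]
```

The result has the same layout as the `nodes` list in `main()`, so it can be passed to the counting loop unchanged.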

Upvotes: 5
