I'm trying to find all cycles in an undirected, unweighted graph whose edges are given as pairs of nodes. Here is the code I wrote:
```python
def find_cycles(graph):
    cycles = []

    def dfs(node, visited, path):
        visited.add(node)
        path.append(node)
        neighbors = graph.get(node, [])
        for neighbor in neighbors:
            if neighbor in visited:
                # Cycle detected
                start_index = path.index(neighbor)
                cycle = path[start_index:]
                if cycle not in cycles:
                    cycles.append(cycle)
            else:
                dfs(neighbor, visited, path)
        visited.remove(node)
        path.pop()

    for node in graph.keys():
        dfs(node, set(), [])
    return cycles

graph = {
    # 'node': ['adjacent'],
}
n, m = map(int, input().split())
for _ in range(m):
    a, b = map(int, input().split())
    if b not in graph:
        graph[b] = []
    if a not in graph:
        graph[a] = []
    if a not in graph[b]:
        graph[b].append(a)
    if b not in graph[a]:
        graph[a].append(b)
ans = find_cycles(graph)
print(ans)
print(len(ans))
```
For this test case:

```
10 10
3 6
9 3
1 7
1 2
4 7
7 6
2 9
2 6
3 4
6 0
```

I know that the shortest cycle has length 4, but the program prints a list of 92 cycles, the shortest of which has length 2. What is wrong with my code?
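I've modified your code very slightly. Instead of checking `if cycle not in cycles`, I keep a set of the cycles already seen, each stored as a tuple of its nodes in sorted order. This way, the same cycle found from a different starting node or in the opposite direction is recognized as a duplicate. If a new cycle is not in the set, I add it to the list of cycles. I also discard any "cycle" of only two nodes. Such a cycle is just the DFS looking back along the edge it arrived by, because the previous node is still in `visited`.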
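To see where the two-node cycles and the duplicates come from, the sketch below runs the question's original `find_cycles` on a small hypothetical graph: a triangle `1-2-3`. The only change to the function is restored indentation. The triangle has exactly one real cycle, yet the function reports every edge in both directions as a "cycle". It also reports the triangle once for each of its orderings. The last line applies the dedup rule described above.

```python
def find_cycles(graph):
    """The question's original function, unchanged apart from indentation."""
    cycles = []

    def dfs(node, visited, path):
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            if neighbor in visited:
                # `neighbor` may just be the node we arrived from (a 2-node "cycle")
                start_index = path.index(neighbor)
                cycle = path[start_index:]
                if cycle not in cycles:  # only catches identical lists, not rotations
                    cycles.append(cycle)
            else:
                dfs(neighbor, visited, path)
        visited.remove(node)
        path.pop()

    for node in graph.keys():
        dfs(node, set(), [])
    return cycles


# Hypothetical example graph: a triangle, which contains exactly one cycle
triangle = {1: [2, 3], 2: [1, 3], 3: [1, 2]}
found = find_cycles(triangle)

two_node = [c for c in found if len(c) == 2]
three_node = [c for c in found if len(c) == 3]
print(len(found))   # 12
print(two_node)     # [[1, 2], [2, 3], [1, 3], [3, 2], [2, 1], [3, 1]]
print(three_node)   # all 6 orderings of the one triangle

# Dedup rule from the answer: drop 2-node cycles, compare sorted node tuples
fixed = {tuple(sorted(c)) for c in found if len(c) > 2}
print(fixed)        # {(1, 2, 3)}
```

The same effect on the 10-edge test graph is what inflates the result to 92 entries.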
With these changes I get 7 cycles, the shortest of which have length 4, and I think the result is correct.
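One way to sanity-check the count of 7 uses the circuit rank r = E - V + C (edges minus nodes plus connected components). This is the dimension of the graph's cycle space over GF(2). Distinct simple cycles have distinct edge sets, and each edge set is a non-zero element of that space. A graph therefore has at most 2^r - 1 simple cycles. The sketch below computes r with a small union-find. `circuit_rank` is a hypothetical helper, and nodes that appear in no edge are ignored.

```python
def circuit_rank(edges):
    """Return E - V + C for the undirected graph given by `edges`.

    Hypothetical helper for illustration; isolated nodes are not counted.
    """
    parent = {}

    def find(x):
        # Union-find root lookup with path halving
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    unique_edges = {tuple(sorted(e)) for e in edges}  # ignore repeated edges
    for a, b in unique_edges:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb  # merge the two components

    components = len({find(x) for x in list(parent)})
    return len(unique_edges) - len(parent) + components


# Edge list from the question's test case
EDGES = [(3, 6), (9, 3), (1, 7), (1, 2), (4, 7), (7, 6), (2, 9), (2, 6), (3, 4), (6, 0)]
r = circuit_rank(EDGES)
print(r)           # 3  (10 edges, 8 nodes, 1 component)
print(2 ** r - 1)  # 7  (upper bound on the number of simple cycles)
```

The 7 cycles in the output are distinct, and each can be checked against the edge list. The bound is therefore reached, which means no cycle is missing for this input.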
```python
data = """10 10
3 6
9 3
1 7
1 2
4 7
7 6
2 9
2 6
3 4
6 0""".splitlines()

def find_cycles(graph):
    cycles = []
    checked = set()  # sorted node tuples of the cycles already recorded

    def dfs(node, visited, path):
        visited.add(node)
        path.append(node)
        neighbors = graph.get(node, [])
        for neighbor in neighbors:
            if neighbor in visited:
                # Cycle detected
                start_index = path.index(neighbor)
                cycle = path[start_index:]
                key = tuple(sorted(cycle))  # same key for rotations and reversals
                # Skip 2-node "cycles" (stepping back along the same edge) and repeats
                if len(cycle) > 2 and key not in checked:
                    checked.add(key)
                    cycles.append(cycle)
            else:
                dfs(neighbor, visited, path)
        visited.remove(node)
        path.pop()

    for node in graph.keys():
        dfs(node, set(), [])
    return cycles

graph = {
    # 'node': ['adjacent'],
}
n, m = map(int, data.pop(0).split())
for _ in range(m):
    a, b = map(int, data.pop(0).split())
    if b not in graph:
        graph[b] = []
    if a not in graph:
        graph[a] = []
    if a not in graph[b]:
        graph[b].append(a)
    if b not in graph[a]:
        graph[a].append(b)
ans = find_cycles(graph)
print(ans)
print(len(ans))
```
Output:

```
[[3, 9, 2, 1, 7, 4], [6, 3, 9, 2, 1, 7], [6, 3, 9, 2], [6, 3, 4, 7, 1, 2], [6, 3, 4, 7], [6, 7, 1, 2], [6, 7, 4, 3, 9, 2]]
7
```
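One caveat about the sorted-tuple key: it identifies a cycle only by its set of nodes, so two different cycles through the same nodes would be merged into one. That does not happen in this test case. It does happen in the complete graph K4, which has three distinct 4-cycles on the nodes {1, 2, 3, 4}. A possible order-preserving key is sketched below, using a hypothetical helper `canonical`. It rotates the cycle so that its smallest node comes first, then keeps whichever of the two directions is lexicographically smaller. Otherwise, the DFS is the same as above.

```python
def canonical(cycle):
    """Hypothetical helper: equal for rotations/reversals of the same cycle,
    but different for distinct cycles that share the same node set."""
    i = cycle.index(min(cycle))
    rotated = cycle[i:] + cycle[:i]            # smallest node first
    reversed_ = [rotated[0]] + rotated[:0:-1]  # same start, opposite direction
    return tuple(min(rotated, reversed_))


def find_cycles(graph):
    cycles = []
    checked = set()  # canonical keys of the cycles already recorded

    def dfs(node, visited, path):
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            if neighbor in visited:
                cycle = path[path.index(neighbor):]
                key = canonical(cycle)
                if len(cycle) > 2 and key not in checked:
                    checked.add(key)
                    cycles.append(cycle)
            else:
                dfs(neighbor, visited, path)
        visited.remove(node)
        path.pop()

    for node in graph:
        dfs(node, set(), [])
    return cycles


# K4: every pair of the four nodes is connected
k4 = {1: [2, 3, 4], 2: [1, 3, 4], 3: [1, 2, 4], 4: [1, 2, 3]}
found = find_cycles(k4)
print(len(found))                              # 7: four triangles and three 4-cycles
print(len({tuple(sorted(c)) for c in found}))  # 5: a sorted key merges the three 4-cycles
```

On the question's test graph both keys give the same 7 cycles, because no two of its cycles share a node set.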