Ludisposed

Reputation: 1769

Python group connecting values

I have a list like this:

board = [["X", "X", "X", "X"],
         ["X", "O", "O", "X"],
         ["X", "X", "O", "X"],
         ["X", "O", "X", "X"]]

How can I group all connecting neighbours?

For instance, I can find all the individual "O"s like this:

all_o = [
            (x, y)
            for x, row in enumerate(board)
            for y, col in enumerate(row)
            if col == 'O'
        ]

Yielding me this: [(1, 1), (1, 2), (2, 2), (3, 1)]


Now I want to group the connected items, so that my expected output would be:

[[(1, 1), (1, 2), (2, 2)], [(3, 1)]]

I don't quite know how to approach this. I've been looking at itertools.groupby; is that something I could use here?

Upvotes: 0

Views: 282

Answers (2)

Antoine Zambelli

Reputation: 764

Not sure if there are built-in functions for this, but here is a simple brute-force algorithm to get you started:

EDIT: rewrote the brute-force code to handle concave groupings.

import numpy as np

board = [[0, 0, 0, 0],
 [0, 1, 0, 1],
 [0, 1, 1, 1],
 [0, 0, 0, 0],
 [1, 1, 0, 0]]

board = np.array(board)
board = np.pad(board,(1,1),'constant',constant_values=(0,0))
tt = np.where(board == 1)
idx_match = list(zip(tt[0],tt[1]))#list() so it can be indexed and sliced in Python 3

label_mask = np.zeros(board.shape)
labels = []
labels.append(1)
label_mask[idx_match[0]] = 1#initial case

#this loop labels convex clusters, but can't tell if they are touching
for i in idx_match[1:]:
    pts = [(i[0],i[1]-1),(i[0]-1,i[1]),(i[0],i[1]+1),(i[0]+1,i[1])]
    idx = []
    for pt in pts:
        if label_mask[pt] in labels:
            idx.append(labels.index(label_mask[pt]))

    if idx:
        idx = min(idx)
        label_mask[i] = labels[idx]
    else:
        labels.append(labels[-1] + 1)
        label_mask[i] = labels[-1]

#this loop goes through detected clusters and groups them together if adjacent
for lab in labels:
    adj_vals = []
    tt = np.where(label_mask == lab)
    idx = zip(tt[0],tt[1])

    for i in idx:
        pts = [(i[0],i[1]-1),(i[0]-1,i[1]),(i[0],i[1]+1),(i[0]+1,i[1])]
        adj_vals.extend([label_mask[pt] for pt in pts if label_mask[pt] != 0])

    adj_vals = list(set(adj_vals))

    #merge all labels adjacent to this cluster; this must run once per label, inside the loop
    if adj_vals:
        label_mask[np.isin(label_mask,adj_vals)] = min(adj_vals)

label_mask = label_mask[1:-1,1:-1]#remove padding
print(label_mask)

This gives us:

[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]
 [ 0.  1.  1.  1.]
 [ 0.  0.  0.  0.]
 [ 3.  3.  0.  0.]]

(pre-edit, the (0,3) point would have been labeled as 2)
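
The same label-then-merge idea can also be written without NumPy. What follows is only a minimal sketch, assuming 4-connectivity and a board of 0/1 values as above. The names `label_components`, `find` and `parent` are made up for illustration, and the merge step uses a small union-find structure instead of the second loop shown earlier, so it is a swapped-in technique rather than the code from this answer.

```python
def label_components(board):
    """Two-pass labelling of 4-connected cells equal to 1 (illustrative sketch)."""
    rows, cols = len(board), len(board[0])
    labels = [[0] * cols for _ in range(rows)]
    parent = {}  # union-find forest: provisional label -> parent label

    def find(lab):
        # Follow parent links up to the representative label
        while parent[lab] != lab:
            parent[lab] = parent[parent[lab]]  # path halving
            lab = parent[lab]
        return lab

    next_label = 1
    # First pass: give each cell a provisional label and record touching labels
    for r in range(rows):
        for c in range(cols):
            if board[r][c] != 1:
                continue
            up = labels[r - 1][c] if r > 0 else 0
            left = labels[r][c - 1] if c > 0 else 0
            neighbours = [lab for lab in (up, left) if lab]
            if not neighbours:
                parent[next_label] = next_label
                labels[r][c] = next_label
                next_label += 1
            else:
                roots = [find(lab) for lab in neighbours]
                smallest = min(roots)
                labels[r][c] = smallest
                for root in roots:
                    parent[root] = smallest  # merge labels that touch here

    # Second pass: replace each provisional label by its representative
    for r in range(rows):
        for c in range(cols):
            if labels[r][c]:
                labels[r][c] = find(labels[r][c])
    return labels


board = [[0, 0, 0, 0],
         [0, 1, 0, 1],
         [0, 1, 1, 1],
         [0, 0, 0, 0],
         [1, 1, 0, 0]]

for row in label_components(board):
    print(row)
```

On this board the sketch should produce the same labels as the output above (as integers rather than floats), with the (0,3) point merged into label 1.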

Upvotes: 2

Cihan

Reputation: 2307

How about doing a depth-first search to find all connected components?

def dfs(board, start):
    visited, stack = set(), [start]
    while stack:
        vertex = stack.pop()
        cur_i, cur_j = vertex[0], vertex[1]

        if vertex not in visited:
            visited.add(vertex)
            for i, j in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                ni, nj = cur_i + i, cur_j + j
                # only follow unvisited, in-bounds neighbours that are "O"
                if ((ni, nj) not in visited
                        and 0 <= ni < len(board)
                        and 0 <= nj < len(board[0])
                        and board[ni][nj] == "O"):
                    stack.append((ni, nj))
    return list(visited)

board = [["X", "X", "X", "X"],
         ["X", "O", "O", "X"],
         ["X", "X", "O", "X"],
         ["X", "O", "X", "X"]]

all_o = [(x, y) for x, row in enumerate(board) for y, col in enumerate(row) if col == 'O']

res = []
done = set()
for cur in all_o:
    if cur not in done:
        comps = dfs(board, cur)
        done |= set(comps)
        res.append(comps)
print(res)

Output (the order of points within each group may vary, because each group is built from a set):

[[(1, 2), (1, 1), (2, 2)], [(3, 1)]]

Reference: I modified the DFS code from this blog post.
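
If you need the groups in exactly the order from the question, one option is to sort each component before storing it. Here is a rough sketch of that variant, folding the search and the outer loop into one function. The name `connected_groups` and its `target` parameter are my own choices for illustration; it assumes 4-connectivity, the same as the code above.

```python
def connected_groups(board, target="O"):
    """Group 4-connected cells equal to `target`; each group is sorted (illustrative sketch)."""
    rows = len(board)
    seen = set()
    groups = []
    for i in range(rows):
        for j in range(len(board[i])):
            if board[i][j] != target or (i, j) in seen:
                continue
            # Iterative depth-first search from an unvisited target cell
            stack, group = [(i, j)], []
            seen.add((i, j))
            while stack:
                ci, cj = stack.pop()
                group.append((ci, cj))
                for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                    ni, nj = ci + di, cj + dj
                    if (0 <= ni < rows and 0 <= nj < len(board[ni])
                            and board[ni][nj] == target
                            and (ni, nj) not in seen):
                        seen.add((ni, nj))
                        stack.append((ni, nj))
            groups.append(sorted(group))
    return groups


board = [["X", "X", "X", "X"],
         ["X", "O", "O", "X"],
         ["X", "X", "O", "X"],
         ["X", "O", "X", "X"]]

print(connected_groups(board))  # [[(1, 1), (1, 2), (2, 2)], [(3, 1)]]
```

Because the board is scanned row by row and each group is sorted, the result no longer depends on set iteration order.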

Upvotes: 1
