hkk

How to Find the Neighbors of a Cell in an ndarray?

I'm working with n-dimensional arrays in Python, and I want to find the "neighbors" (adjacent cells) of a given cell based on its coordinates. The issue is that I don't know the number of dimensions beforehand.

I attempted to use numpy.roll as suggested by this answer, but it's not clear to me how to apply this method to multiple dimensions.

Please point me in the right direction.
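Since the question mentions numpy.roll: np.roll accepts tuples for both shift and axis, so a single call can shift along several axes at once. Looping over every offset in {-1, 0, 1}**ndim then gives, for every cell simultaneously, the value of its neighbour at that offset. The sketch below sums those values. The function name neighbour_sum_periodic is hypothetical, and the sketch assumes periodic boundaries, because np.roll wraps values around the edges instead of dropping them.

```python
import itertools
import numpy as np

def neighbour_sum_periodic(a):
    """Hypothetical helper: for every cell of `a`, sum the 3**ndim - 1
    surrounding values, treating every axis as periodic (np.roll wraps)."""
    a = np.asarray(a)
    axes = tuple(range(a.ndim))
    total = np.zeros_like(a)
    # every offset vector in {-1, 0, 1}**ndim
    for offset in itertools.product((-1, 0, 1), repeat=a.ndim):
        if not any(offset):
            continue  # skip the cell itself
        # rolling by -offset moves a[idx + offset] to position idx
        total += np.roll(a, shift=tuple(-o for o in offset), axis=axes)
    return total
```

For a 3x3 input every other cell is adjacent once wraparound is allowed, so the result equals a.sum() - a, which makes a convenient sanity check.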

ali_m

I'm going to assume that you have an (ndims,) vector of indices specifying some point p, and you want an (m, ndims) array of indices corresponding to the locations of every adjacent element in the array (including diagonally adjacent elements), where m is at most 3**ndims - 1.

Starting out with your indexing vector p, you want to offset each element by every possible combination of -1, 0 and +1. This can be done by using np.indices to enumerate every length-ndims combination of {0, 1, 2}, mapping 0, 1, 2 onto -1, 0, +1 to get a (3**ndims, ndims) array of offsets, and then adding these offsets to p.
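For comparison, the same set of offsets can be produced with itertools.product, which may make the np.indices trick easier to read. A minimal sketch (both function names are hypothetical) that builds the offsets both ways; both enumerate the combinations with the last axis varying fastest, so the rows come out in the same order:

```python
import itertools
import numpy as np

def offsets_via_product(ndim):
    # every length-ndim tuple over {-1, 0, 1}, last position varying fastest
    return np.array(list(itertools.product((-1, 0, 1), repeat=ndim)))

def offsets_via_indices(ndim):
    # the np.indices trick: enumerate {0, 1, 2}**ndim, then map 0, 1, 2 -> -1, 0, 1
    idx = np.indices((3,) * ndim).reshape(ndim, -1).T
    return np.r_[-1, 0, 1].take(idx)

print(offsets_via_product(2).shape)  # (9, 2)
```

Because take only maps 0, 1, 2 onto -1, 0, 1 here, subtracting 1 from the index array would give the same result.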

You might want to exclude point p itself (i.e. where offset == np.array([0, 0, ..., 0])), and you may also need to exclude out-of-bounds indices.
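The number m of neighbours therefore depends on where p sits: 3**ndims - 1 for a cell away from the edges, fewer near them. Since the in-bounds neighbourhood is a box, m can be computed without building any array, as a product of per-axis counts. A sketch using a hypothetical helper name, which can serve as a cross-check on the filtered result:

```python
def count_in_bounds_neighbours(p, shape):
    """Hypothetical helper: number of cells adjacent to p (diagonals included,
    p itself excluded) that lie inside an array of the given shape."""
    total = 1
    for pp, n in zip(p, shape):
        lo = max(pp - 1, 0)      # lowest valid index along this axis
        hi = min(pp + 1, n - 1)  # highest valid index along this axis
        total *= hi - lo + 1     # valid positions along this axis
    return total - 1             # remove p itself
```

For p = (4, 5) in a 6x6 array this gives 3 * 2 - 1 = 5.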

import numpy as np

def get_neighbours(p, exclude_p=True, shape=None):

    ndim = len(p)

    # generate a (3**ndim, ndim) array containing every length-ndim combination of {0, 1, 2}:
    offset_idx = np.indices((3,) * ndim).reshape(ndim, -1).T

    # use these to index into np.array([-1, 0, 1]) to get offsets
    offsets = np.r_[-1, 0, 1].take(offset_idx)

    # optional: exclude offsets of 0, 0, ..., 0 (i.e. p itself)
    if exclude_p:
        offsets = offsets[np.any(offsets, 1)]

    neighbours = p + offsets    # apply offsets to p

    # optional: exclude out-of-bounds indices
    if shape is not None:
        valid = np.all((neighbours < np.array(shape)) & (neighbours >= 0), axis=1)
        neighbours = neighbours[valid]

    return neighbours
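The function above treats diagonal cells as neighbours. If only cells that share a face with p are wanted (2 * ndim of them for an interior cell), the offsets reduce to plus and minus the unit vectors. A sketch under that assumption, with a hypothetical function name; the bounds filter mirrors the one in get_neighbours:

```python
import numpy as np

def get_face_neighbours(p, shape=None):
    """Hypothetical variant: only neighbours that differ from p along one axis."""
    p = np.asarray(p)
    ndim = len(p)
    # +1 and -1 unit steps along each axis: a (2 * ndim, ndim) array
    eye = np.eye(ndim, dtype=int)
    offsets = np.concatenate([eye, -eye])
    neighbours = p + offsets
    # optional: exclude out-of-bounds indices
    if shape is not None:
        valid = np.all((neighbours >= 0) & (neighbours < np.array(shape)), axis=1)
        neighbours = neighbours[valid]
    return neighbours
```

Equivalently, one could keep only the rows of offsets in get_neighbours where np.abs(offsets).sum(axis=1) == 1.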

Here's a 2D example that's easy to visualize:

p = np.r_[4, 5]
shape = (6, 6)

neighbours = get_neighbours(p, shape=shape)

x = np.zeros(shape, int)
x[tuple(neighbours.T)] = 1
x[tuple(p)] = 2

print(x)
# [[0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 1 1]
#  [0 0 0 0 1 2]
#  [0 0 0 0 1 1]]

The same function works unchanged for any number of dimensions.
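As a quick check in 3D: an interior cell of a 3x3x3 array should have 26 neighbours and a corner cell 7. The snippet repeats get_neighbours unchanged so it runs on its own:

```python
import numpy as np

def get_neighbours(p, exclude_p=True, shape=None):
    # same logic as the function above, repeated so this snippet is standalone
    ndim = len(p)
    offset_idx = np.indices((3,) * ndim).reshape(ndim, -1).T
    offsets = np.r_[-1, 0, 1].take(offset_idx)
    if exclude_p:
        offsets = offsets[np.any(offsets, 1)]
    neighbours = p + offsets
    if shape is not None:
        valid = np.all((neighbours < np.array(shape)) & (neighbours >= 0), axis=1)
        neighbours = neighbours[valid]
    return neighbours

shape = (3, 3, 3)

centre = get_neighbours(np.r_[1, 1, 1], shape=shape)
corner = get_neighbours(np.r_[0, 0, 0], shape=shape)

x = np.zeros(shape, int)
x[tuple(centre.T)] = 1   # marks every cell except the centre

print(centre.shape)   # (26, 3)
print(corner.shape)   # (7, 3)
print(x.sum())        # 26
```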


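If you just want to be able to index the "neighbourhood" of p and you don't care about excluding p itself, a much simpler and faster option would be to use a tuple of slice objects. Clip the lower bound at 0: for a cell on the edge (pp == 0), slice(-1, 2) would be interpreted relative to the end of the axis and usually select nothing. The upper bound needs no clipping, since slicing past the end of an axis is allowed.

idx = tuple(slice(max(pp - 1, 0), pp + 2) for pp in p)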
print(x[idx])
# [[1 1]
#  [1 2]
#  [1 1]]
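The slicing approach also works in any number of dimensions. A sketch wrapping it in a helper (the name neighbourhood_view is hypothetical) that clips the lower bound so edge cells work; note that basic slicing returns a view, so writing through the result modifies the original array:

```python
import numpy as np

def neighbourhood_view(a, p):
    # hypothetical helper: the (up to) 3 x 3 x ... block around p, p included;
    # the lower bound is clipped at 0 so cells on an edge are handled
    idx = tuple(slice(max(pp - 1, 0), pp + 2) for pp in p)
    return a[idx]

a = np.arange(27).reshape(3, 3, 3)
corner = neighbourhood_view(a, (0, 0, 0))
print(corner.shape)  # (2, 2, 2)

# basic slicing returns a view, so this write goes through to `a`
corner[...] = -1
print(a[1, 1, 1])    # -1
```

If an independent copy is needed, call .copy() on the result.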
