Reputation: 2149
I'm working with n-dimensional arrays in Python, and I want to find the "neighbors" (adjacent cells) of a given cell based on its coordinates. The issue is that I don't know the number of dimensions beforehand.
I attempted to use numpy.roll as suggested by this answer, but it isn't clear to me how to apply this method to multiple dimensions.
Please point me in the right direction.
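For reference, `np.roll` accepts a tuple of shifts together with a tuple of axes, so the roll-based idea can be extended to n dimensions by rolling once for every combination of -1, 0 and +1 shifts. This is a minimal sketch, assuming periodic (wrap-around) boundaries are acceptable, because `np.roll` moves values off one edge and back in at the opposite edge. The helper name `neighbour_stack` is hypothetical:

```python
import itertools
import numpy as np

def neighbour_stack(arr):
    """Hypothetical helper: return an array of shape (3**ndim - 1, *arr.shape).

    Layer k holds, for every cell, the value of one of its neighbours.
    Boundaries wrap around because np.roll is periodic.
    """
    ndim = arr.ndim
    axes = tuple(range(ndim))
    layers = []
    for shift in itertools.product((-1, 0, 1), repeat=ndim):
        if not any(shift):
            continue  # skip the all-zero shift (the cell itself)
        # rolling by `shift` moves the value at cell c - shift to cell c,
        # so every cell receives its neighbour at offset -shift
        layers.append(np.roll(arr, shift=shift, axis=axes))
    return np.stack(layers)
```

For example, `neighbour_stack(a)[:, i, j]` lists the 8 neighbour values of cell `(i, j)` in a 2D array `a`, and `neighbour_stack(a).sum(axis=0)` gives every cell's neighbour sum at once. If wrap-around is not what you want at the edges, an index-based approach avoids it.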
Upvotes: 2
Views: 3997
Reputation: 74262
I'm going to assume that you have an (ndims,) vector of indices specifying some point p, and you want an (m, ndims) array of indices corresponding to the locations of every adjacent element in the array (including diagonally adjacent elements), where m is at most 3**ndims - 1.
Starting out with your indexing vector p, you want to offset each element by every possible combination of -1, 0 and +1. This can be done by using np.indices to generate a (3**ndims, ndims) array of indices over {0, 1, 2}, mapping those indices to offsets in {-1, 0, +1}, then adding the offsets to p.
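To see what the np.indices step produces, here is the offset generation on its own for a 2D point. The comparison against `itertools.product` is only an illustrative cross-check (an alternative way to enumerate the same offsets), not part of the approach itself:

```python
import itertools
import numpy as np

ndim = 2
# np.indices((3, 3)) has shape (2, 3, 3); flatten the grid and transpose
# so that each row is one index tuple over {0, 1, 2}
offset_idx = np.indices((3,) * ndim).reshape(ndim, -1).T
# look up each index in [-1, 0, 1]: 0 -> -1, 1 -> 0, 2 -> +1
offsets = np.r_[-1, 0, 1].take(offset_idx)
print(offsets.tolist())
# [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]

# the same 3**ndim offsets, enumerated in the same row-major order
expected = list(itertools.product((-1, 0, 1), repeat=ndim))
assert [tuple(row) for row in offsets.tolist()] == expected
```

Because the alphabet {0, 1, 2} is consecutive, `offset_idx - 1` gives the same offsets as the `take` lookup.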
You might want to exclude point p itself (i.e. where offset == np.array([0, 0, ..., 0])), and you may also need to exclude out-of-bounds indices.
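Both filtering steps matter most for a point on a corner, where most candidate offsets fall outside the array. This sketch uses the arbitrary example values p = (0, 0) and shape (3, 3), and builds the offsets by subtracting 1 from the np.indices grid, which is equivalent to the lookup into [-1, 0, 1]:

```python
import numpy as np

p = np.array([0, 0])   # example point on the corner of the array
shape = (3, 3)         # example array shape
ndim = len(p)

offsets = np.indices((3,) * ndim).reshape(ndim, -1).T - 1
# the row whose offsets are all zero corresponds to p itself
is_self = ~np.any(offsets, axis=1)
candidates = p + offsets[~is_self]
# keep only indices with 0 <= index < shape along every axis
in_bounds = np.all((candidates >= 0) & (candidates < np.array(shape)), axis=1)
neighbours = candidates[in_bounds]
print(neighbours.tolist())
# [[0, 1], [1, 0], [1, 1]]
```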
import numpy as np

def get_neighbours(p, exclude_p=True, shape=None):
    ndim = len(p)

    # generate a (3**ndim, ndim) array containing all strings of length ndim
    # over the alphabet {0, 1, 2}:
    offset_idx = np.indices((3,) * ndim).reshape(ndim, -1).T

    # use these to index into np.array([-1, 0, 1]) to get offsets
    offsets = np.r_[-1, 0, 1].take(offset_idx)

    # optional: exclude offsets of 0, 0, ..., 0 (i.e. p itself)
    if exclude_p:
        offsets = offsets[np.any(offsets, 1)]

    neighbours = p + offsets    # apply offsets to p

    # optional: exclude out-of-bounds indices
    if shape is not None:
        valid = np.all((neighbours < np.array(shape)) & (neighbours >= 0), axis=1)
        neighbours = neighbours[valid]

    return neighbours
Here's a 2D example that's easy to visualize:
p = np.r_[4, 5]
shape = (6, 6)
neighbours = get_neighbours(p, shape=shape)
x = np.zeros(shape, int)
x[tuple(neighbours.T)] = 1
x[tuple(p)] = 2
print(x)
# [[0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 1 1]
#  [0 0 0 0 1 2]
#  [0 0 0 0 1 1]]
This will generalize to any dimensionality.
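As a quick check of the n-dimensional case, the sketch below restates `get_neighbours` compactly with equivalent logic (so the block runs on its own) and counts neighbours in a 3D array of shape (4, 4, 4). An interior point should have 3**3 - 1 = 26 neighbours, a point in the middle of a face 17, a point on an edge 11, and a corner 7:

```python
import numpy as np

def get_neighbours(p, exclude_p=True, shape=None):
    # equivalent to the function above, restated so this block is self-contained
    p = np.asarray(p)
    ndim = len(p)
    offsets = np.indices((3,) * ndim).reshape(ndim, -1).T - 1
    if exclude_p:
        offsets = offsets[np.any(offsets, axis=1)]
    neighbours = p + offsets
    if shape is not None:
        valid = np.all((neighbours >= 0) & (neighbours < np.array(shape)), axis=1)
        neighbours = neighbours[valid]
    return neighbours

shape = (4, 4, 4)
for point in [(1, 1, 1), (0, 1, 1), (0, 0, 1), (0, 0, 0)]:
    print(point, len(get_neighbours(point, shape=shape)))
# (1, 1, 1) 26
# (0, 1, 1) 17
# (0, 0, 1) 11
# (0, 0, 0) 7
```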
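If you just want to be able to index the "neighbourhood" of p and you don't care about excluding p itself, a much simpler and faster option would be to use a tuple of slice objects. Note that a negative slice start counts from the end of the axis, so clip the start at 0 for points on the lower edge; stops past the upper edge are truncated automatically:
idx = tuple(slice(max(pp - 1, 0), pp + 2) for pp in p)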
print(x[idx])
# [[1 1]
#  [1 2]
#  [1 1]]
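The slice approach can also be combined with a boolean mask to drop p itself. This sketch assumes you want the neighbour values rather than their indices, and it clips each slice start at 0 because a start such as -1 would be counted from the end of the axis (`x[-1:2]` on a length-6 axis is empty). The helper name `neighbour_values` is hypothetical:

```python
import numpy as np

def neighbour_values(x, p):
    """Hypothetical helper: values of all cells adjacent to p, excluding p."""
    # clip starts at 0 so points on the lower edge don't wrap to the end;
    # NumPy already truncates stops that run past the upper edge
    starts = [max(pp - 1, 0) for pp in p]
    idx = tuple(slice(s, pp + 2) for s, pp in zip(starts, p))
    window = x[idx]
    # position of p inside the (possibly truncated) window
    centre = tuple(pp - s for pp, s in zip(p, starts))
    mask = np.ones(window.shape, dtype=bool)
    mask[centre] = False
    return window[mask]

x = np.arange(36).reshape(6, 6)
print(neighbour_values(x, (4, 5)))  # [22 23 28 34 35]
print(neighbour_values(x, (0, 0)))  # [1 6 7]
```

Boolean-mask indexing returns the remaining values flattened in C order, so the result is 1D regardless of the number of dimensions.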
Upvotes: 7