2daaa


How do I select a window from a numpy array with periodic boundary conditions?

Suppose I make a 2d array like this:

>>> A=np.arange(16).reshape((4,4))
>>> A
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

and I want to select a 3x3 window around any given element, with the window wrapping around the boundaries of the array. How would I do that? I know how to do this when the window doesn't cross the boundaries of the original array:

>>> A[1:4,0:3]
array([[ 4,  5,  6],
       [ 8,  9, 10],
       [12, 13, 14]])

but an expression like A[i-1:i+2, j-1:j+2] returns an empty array, for example when i=0, j=0.
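The empty result comes from Python's slicing rules: a negative slice start counts from the end of the axis, so for `i = 0` the row slice `-1:2` on a length-4 axis means `3:2`, which selects nothing. The sketch below shows that cause, plus one possible workaround that is not taken from the answers on this page: `np.pad` with `mode='wrap'` surrounds the array with wrapped copies of its edge rows and columns, so every window becomes an ordinary in-bounds slice.

```python
import numpy as np

A = np.arange(16).reshape((4, 4))
i, j = 0, 0

# A negative slice start counts from the end of the axis, so on a length-4 axis
# the slice -1:2 is the same as 3:2, which is empty
window = A[i-1:i+2, j-1:j+2]
print(window.shape)  # (0, 0)

# Pad one wrapped row/column on every side; padded[i:i+3, j:j+3] is then
# the 3x3 window centered on A[i, j]
padded = np.pad(A, 1, mode='wrap')
print(padded[i:i+3, j:j+3])
```

For an n x n window with odd n, pad by n // 2 instead of 1. The padding is done once, after which each window is a cheap slice (a view, not a copy).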


Answers (2)

gauku


I can't comment yet, but I wanted to suggest an improvement over unutbu's solution.

Their solution only centers the window correctly when n=3, so it can't handle cases like this:

import numpy as np

A=np.arange(25).reshape((5,5))

print(A) 
# [[ 0  1  2  3  4]
#  [ 5  6  7  8  9]
#  [10 11 12 13 14]
#  [15 16 17 18 19]
#  [20 21 22 23 24]]

print(neighbors(A, 0, 0, n=5))
# [[24 20 21 22 23]
#  [ 4  0  1  2  3]
#  [ 9  5  6  7  8]
#  [14 10 11 12 13]
#  [19 15 16 17 18]]

0 should be in the center, but it is off by one row and one column.
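The off-by-one follows from how `np.roll` moves elements: rolling by `s` along an axis of length `N` sends index `i` to `(i + s) % N`. To bring row `x` to the center row `n // 2` of an n-row window, the shift must be `s = n // 2 - x`; the original `s = 1 - x` hard-codes `n // 2 == 1`, which only holds for n = 3. A small sketch checking both facts on a 1-D array:

```python
import numpy as np

N = 5
a = np.arange(N)

# np.roll(a, s) sends index i to (i + s) % N, for positive and negative s
for s in range(-N, N):
    rolled = np.roll(a, s)
    assert all(rolled[(i + s) % N] == a[i] for i in range(N))

# Centering index x in an n-element window needs s = n // 2 - x
x, n = 0, 5
print(np.roll(a, n // 2 - x)[n // 2] == a[x])  # True
print(np.roll(a, 1 - x)[n // 2] == a[x])       # False: the shift used for n=3
```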

A small modification to the shift values fixes it:

def neighbors_updated(arr, x, y, n_row=3, n_col=3):
    ''' Given a 2D array, returns an n_row x n_col array whose "center" element is arr[x,y]'''
    # Roll rows and columns so that arr[x,y] lands at (n_row//2, n_col//2)
    arr=np.roll(np.roll(arr,shift=-x+n_row//2,axis=0),shift=-y+n_col//2,axis=1)
    return arr[:n_row,:n_col]

print(neighbors_updated(A, 0, 0, n_row=5, n_col=5))
# [[18 19 15 16 17]
#  [23 24 20 21 22]
#  [ 3  4  0  1  2]
#  [ 8  9  5  6  7]
#  [13 14 10 11 12]]
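As a quick sanity check (a sketch that copies both functions from this page), the updated version agrees with the original for the default 3x3 window and also handles non-square windows, here 3 rows by 5 columns centered on A[0, 0]:

```python
import numpy as np

def neighbors(arr, x, y, n=3):
    # unutbu's original version: the shift of 1 centers only 3x3 windows
    arr = np.roll(np.roll(arr, shift=-x + 1, axis=0), shift=-y + 1, axis=1)
    return arr[:n, :n]

def neighbors_updated(arr, x, y, n_row=3, n_col=3):
    # the corrected version: shift by half the window size along each axis
    arr = np.roll(np.roll(arr, shift=-x + n_row // 2, axis=0), shift=-y + n_col // 2, axis=1)
    return arr[:n_row, :n_col]

A = np.arange(25).reshape((5, 5))

# For the default 3x3 window both versions return the same result everywhere
for x in range(5):
    for y in range(5):
        assert np.array_equal(neighbors(A, x, y), neighbors_updated(A, x, y))

# Non-square window: the center element w[1, 2] is A[0, 0]
w = neighbors_updated(A, 0, 0, n_row=3, n_col=5)
print(w)
print(w[1, 2])
```

Because the result is a slice of the rolled array, a window larger than the array along some axis gets truncated to the array's size on that axis.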


unutbu


import numpy as np

A=np.arange(16).reshape((4,4))

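def neighbors(arr,x,y,n=3):
    ''' Given a 2D-array, returns an nxn array whose "center" element is arr[x,y]'''
    # Rolling by -x+1 rows and -y+1 columns moves arr[x,y] to position (1,1),
    # which is the center of a 3x3 window
    arr=np.roll(np.roll(arr,shift=-x+1,axis=0),shift=-y+1,axis=1)
    return arr[:n,:n]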

print(A)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(neighbors(A,0,0))
# [[15 12 13]
#  [ 3  0  1]
#  [ 7  4  5]]

print(neighbors(A,1,0))
# [[ 3  0  1]
#  [ 7  4  5]
#  [11  8  9]]
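`np.roll` returns a new array, so the nested calls copy the whole array twice for every window. For large arrays, one alternative (a sketch, not part of the answers above; `window_wrap` is a hypothetical helper name) is to build the wrapped row and column indices with modular arithmetic and pick out only the window with `np.ix_` fancy indexing:

```python
import numpy as np

def window_wrap(arr, x, y, n_row=3, n_col=3):
    """Hypothetical helper: the n_row x n_col window centered on arr[x, y],
    wrapping around the edges, without rolling the whole array."""
    # Wrapped row/column indices, e.g. [-1, 0, 1] % 4 -> [3, 0, 1]
    rows = np.arange(x - n_row // 2, x - n_row // 2 + n_row) % arr.shape[0]
    cols = np.arange(y - n_col // 2, y - n_col // 2 + n_col) % arr.shape[1]
    # np.ix_ combines the two 1-D index arrays into an open mesh for 2-D indexing
    return arr[np.ix_(rows, cols)]

A = np.arange(16).reshape((4, 4))
print(window_wrap(A, 0, 0))
print(window_wrap(A, 1, 0))
```

This copies only the window itself, and because the indices are taken modulo the array shape, a window larger than the array simply repeats rows or columns instead of being truncated.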

