Chris Parry

Reputation: 3057

Selecting multiple patches from a 3D numpy array

I have a 3D numpy array of size 50x50x4. I also have the locations of several points on the 50x50 plane. For each point, I need to extract an 11x11x4 region centred on the point. The region must wrap around if it overlaps the boundary. What is the most efficient way to do this?

I am currently using a for loop to iterate over each point, take the subset of the 3D array, and store it in a pre-allocated array. Is there a built-in NumPy function that does this? Thank you.


Sorry for the slow response, thank you very much for your input everyone.

Upvotes: 1

Views: 1490

Answers (2)

Divakar

Reputation: 221624

One approach would be to use np.pad in 'wrap' mode on each 2D slice along the last axis, so that both spatial axes wrap around. Then we create sliding windows on this padded version with np.lib.stride_tricks.as_strided. Because these windows are views into the padded array, they don't occupy any more memory. Finally, we index into the sliding windows to get the final output.

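# Based on http://stackoverflow.com/a/41850409/3293881
import numpy as np

def patchify(img, patch_shape):
    # View of every x-by-y window of an (X, Y, a) array: shape (X-x+1, Y-y+1, x, y, a)
    X, Y, a = img.shape
    x, y = patch_shape
    shape = (X - x + 1, Y - y + 1, x, y, a)
    X_str, Y_str, a_str = img.strides
    # Moving the window origin and moving inside a window use the same strides
    strides = (X_str, Y_str, X_str, Y_str, a_str)
    return np.lib.stride_tricks.as_strided(img, shape=shape, strides=strides)

def sliding_patches(a, BSZ):
    # Half-width for an odd block size, e.g. 5 for BSZ=11
    hBSZ = (BSZ-1)//2
    # Wrap-pad each 2D channel slice, then restack along the last axis
    # (np.dstack expects a sequence; passing a generator is deprecated)
    a_ext = np.dstack([np.pad(a[...,i], hBSZ, 'wrap') for i in range(a.shape[2])])
    # out[i, j] is the BSZ x BSZ x channels patch centred at (i, j) of a
    return patchify(a_ext, (BSZ,BSZ))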

Sample run -

In [51]: a = np.random.randint(0,9,(4,5,2)) # Input array

In [52]: a[...,0]
Out[52]: 
array([[2, 7, 5, 1, 0],
       [4, 1, 2, 0, 7],
       [1, 3, 0, 8, 4],
       [8, 0, 5, 2, 7]])

In [53]: a[...,1]
Out[53]: 
array([[0, 3, 3, 8, 7],
       [3, 8, 2, 8, 2],
       [8, 4, 3, 8, 7],
       [6, 6, 8, 5, 5]])

Now, let's select one center point in a, say (1,0), and get the patches of block size (BSZ) 3 around it -

In [54]: out = sliding_patches(a, BSZ=3) # Create sliding windows

In [55]: out[1,0,...,0]  # patch centered at (1,0) for slice-0
Out[55]: 
array([[0, 2, 7],
       [7, 4, 1],
       [4, 1, 3]])

In [56]: out[1,0,...,1]  # patch centered at (1,0) for slice-1
Out[56]: 
array([[7, 0, 3],
       [2, 3, 8],
       [7, 8, 4]])

So, the final output for the patches around (1,0) is simply out[1,0,...,:], i.e. out[1,0].

Let's do a shape check on an array with the original (50,50,4) shape anyway -

In [65]: a = np.random.randint(0,9,(50,50,4))

In [66]: out = sliding_patches(a, BSZ=11)

In [67]: out[1,0].shape
Out[67]: (11, 11, 4)
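The question asks for several points at once. Since out[i, j] is the patch centred at (i, j), passing arrays of row and column coordinates gathers every patch in a single fancy-indexing step; the same indexing works directly on the out returned by sliding_patches. The sketch below is self-contained and assumes NumPy 1.20 or later: it swaps in np.lib.stride_tricks.sliding_window_view (which returns read-only windows) for the hand-written patchify, and extract_patches is a hypothetical helper name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def extract_patches(a, points, BSZ):
    """Hypothetical helper: one wrapped BSZ x BSZ x C patch per point.

    a      : (H, W, C) array
    points : (N, 2) integer array of (row, col) centres
    BSZ    : odd patch size, e.g. 11
    Returns an (N, BSZ, BSZ, C) array.
    """
    h = (BSZ - 1) // 2
    # Wrap-pad only the two spatial axes; the channel axis is left alone.
    a_ext = np.pad(a, ((h, h), (h, h), (0, 0)), mode='wrap')
    # Read-only view of every window, shape (H, W, C, BSZ, BSZ) ...
    win = sliding_window_view(a_ext, (BSZ, BSZ), axis=(0, 1))
    # ... reordered to (H, W, BSZ, BSZ, C), the same layout as sliding_patches.
    win = np.moveaxis(win, 2, -1)
    points = np.asarray(points)
    # Fancy indexing copies only the N requested patches.
    return win[points[:, 0], points[:, 1]]

a = np.random.randint(0, 9, (50, 50, 4))
pts = np.array([[1, 0], [25, 25], [49, 45]])
patches = extract_patches(a, pts, 11)
print(patches.shape)  # (3, 11, 11, 4)
```

The window array itself stays a view of the padded copy, so memory use is one padded array plus the N extracted patches.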

Upvotes: 1

Paul Panzer

Reputation: 53079

Depending on how many times you have to do this, one easy and efficient way would be to pad your original array:

import numpy as np
# Wrap-pad rows, then columns, by 5 = (11 - 1) // 2
p = np.concatenate([a[-5:, ...], a, a[:5, ...]], axis=0)
p = np.concatenate([p[:, -5:, :], p, p[:, :5, :]], axis=1)

Then you can simply slice out the 11x11x4 patch centred on (x0, x1) of the original array:

s = p[x0 : x0 + 11, x1 : x1 + 11, :]
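Slicing the padded array handles one centre at a time. A variant not taken from either answer is to skip the padded copy entirely: compute the wrapped row and column indices with the modulo operator and gather all patches in one fancy-indexing call. This sketch assumes points is an (N, 2) integer array of (row, col) centres and BSZ is odd; extract_patches_mod is a hypothetical name.

```python
import numpy as np

def extract_patches_mod(a, points, BSZ=11):
    """Hypothetical helper: gather wrapped BSZ x BSZ patches around each centre.

    Wrapped indices come from % instead of padding the array.
    Returns an (N, BSZ, BSZ, C) array.
    """
    H, W = a.shape[:2]
    h = (BSZ - 1) // 2
    offsets = np.arange(-h, h + 1)               # -5..5 for BSZ=11
    points = np.asarray(points)
    rows = (points[:, 0, None] + offsets) % H    # (N, BSZ), wrapped row indices
    cols = (points[:, 1, None] + offsets) % W    # (N, BSZ), wrapped column indices
    # Broadcast to (N, BSZ, BSZ) index grids; the channel axis comes along whole.
    return a[rows[:, :, None], cols[:, None, :]]
```

For a 50x50x4 array and N centres this returns an (N, 11, 11, 4) array whose k-th entry equals p[x0 : x0 + 11, x1 : x1 + 11, :] for the k-th centre.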

Upvotes: 0
