Reputation: 706
Currently, I have a 4d array, say,
arr = np.arange(48).reshape((2,2,3,4))
I want to apply a function that takes a 2d array as input to each 2d array sliced from arr. I have searched and read this question, which is exactly what I want.
The function I'm using is im2col_sliding_broadcasting(), which I got from here. It takes a 2d array and a list of 2 elements as input and returns a 2d array. In my case, it takes a 3x4 2d array and the list [2, 2] and returns a 4x6 2d array.
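To make the 4x6 shape concrete: it follows from the row_extent and col_extent formulas in the function. Here is a minimal sketch, assuming stepsize=1; the helper name im2col_output_shape is hypothetical and only used for illustration.

```python
def im2col_output_shape(input_shape, block_shape):
    """Hypothetical helper: shape returned by im2col for a 2d input (stepsize=1)."""
    m, n = input_shape
    b0, b1 = block_shape
    row_extent = m - b0 + 1  # number of vertical block positions
    col_extent = n - b1 + 1  # number of horizontal block positions
    # One row per element of the block, one column per block position
    return (b0 * b1, row_extent * col_extent)

print(im2col_output_shape((3, 4), (2, 2)))    # (4, 6)
print(im2col_output_shape((64, 64), (2, 2)))  # (4, 3969)
```

So for a 64x64 slice with a [2, 2] block, each 2d result would be 4x3969.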
I considered using apply_along_axis(), but as noted there, it only accepts a function that operates on 1d arrays, so I can't apply the im2col function this way.
I want an output with shape 2x2x4x6. Surely I can achieve this with a for loop, but I have heard that it is too slow:
import numpy as np

def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
    # source: https://stackoverflow.com/a/30110497/10666066
    # Parameters
    M, N = A.shape
    col_extent = N - BSZ[1] + 1
    row_extent = M - BSZ[0] + 1
    # Get starting block indices
    start_idx = np.arange(BSZ[0])[:, None]*N + np.arange(BSZ[1])
    # Get offset indices across the height and width of the input array
    offset_idx = np.arange(row_extent)[:, None]*N + np.arange(col_extent)
    # Get all actual indices & index into input array for final output
    return np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel()[::stepsize])

arr = np.arange(48).reshape((2,2,3,4))
output = np.empty([2,2,4,6])
for i in range(2):
    for j in range(2):
        temp = im2col_sliding_broadcasting(arr[i, j], [2,2])
        output[i, j] = temp
My actual arr is a 10000x3x64x64 array, so my question is: is there a more efficient way to do this?
Upvotes: 1
Views: 207
Reputation: 221614
We can leverage scikit-image's view_as_windows, which is based on np.lib.stride_tricks.as_strided, to get sliding windows. More info on the use of as_strided-based view_as_windows.
from skimage.util.shape import view_as_windows

W1,W2 = 2,2 # window size

# Create sliding windows along the last two axes
w = view_as_windows(arr,(1,1,W1,W2))[...,0,0,:,:]

# Merge the window axes (the last two axes) and
# merge the axes along which those windows were created (3rd and 4th axes)
outshp = arr.shape[:-2] + (W1*W2,) + ((arr.shape[-2]-W1+1)*(arr.shape[-1]-W2+1),)
out = w.transpose(0,1,4,5,2,3).reshape(outshp)
The last step (the transpose followed by reshape) forces a copy, so skip it if the windowed view w is enough for what you need.
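If scikit-image isn't available, the same idea can be sketched with NumPy alone. This assumes NumPy 1.20 or newer, where np.lib.stride_tricks.sliding_window_view exists. The helper name im2col_last_two_axes is hypothetical, and this is an illustrative sketch rather than a tuned implementation.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def im2col_last_two_axes(arr, window):
    """Hypothetical helper: im2col over the last two axes of an N-d array."""
    w1, w2 = window
    # View of shape arr.shape[:-2] + (rows_out, cols_out, w1, w2); no data copied
    windows = sliding_window_view(arr, (w1, w2), axis=(-2, -1))
    rows_out, cols_out = windows.shape[-4], windows.shape[-3]
    nd = windows.ndim
    lead = tuple(range(nd - 4))
    # Put the window axes ahead of the position axes, then merge each pair
    moved = windows.transpose(lead + (nd - 2, nd - 1, nd - 4, nd - 3))
    out = moved.reshape(arr.shape[:-2] + (w1 * w2, rows_out * cols_out))
    return windows, out

arr = np.arange(48).reshape((2, 2, 3, 4))
windows, out = im2col_last_two_axes(arr, (2, 2))

print(out.shape)                        # (2, 2, 4, 6)
print(out[0, 0, 0])                     # [0 1 2 4 5 6]
print(np.shares_memory(arr, windows))   # True: the windows are a view
print(np.shares_memory(arr, out))       # False: the reshape made a copy
```

The two shares_memory checks show where the copy happens: the windowed view is free, while the final reshape allocates the full output.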
Upvotes: 1