I have a 2D NumPy ndarray, x, that I need to split into square subregions of size s. For each subregion, I want to get the greatest element (which I can already do) and its position within that subregion (which I can't figure out).
Here is a minimal example:
>>> import numpy as np
>>> x = np.random.randint(0, 10, (6,8))
>>> x
array([[9, 4, 8, 9, 5, 7, 3, 3],
       [3, 1, 8, 0, 7, 7, 5, 1],
       [7, 7, 3, 6, 0, 2, 1, 0],
       [7, 3, 9, 8, 1, 6, 7, 7],
       [1, 6, 0, 7, 5, 1, 2, 0],
       [8, 7, 9, 5, 8, 3, 6, 0]])
>>> h, w = x.shape
>>> s = 2
>>> f = x.reshape(h//s, s, w//s, s)
>>> mx = np.max(f, axis=(1, 3))
>>> mx
array([[9, 9, 7, 5],
       [7, 9, 6, 7],
       [8, 9, 8, 6]])
For example, the 8 in the lower left corner of mx is the greatest element from subregion [[1, 6], [8, 7]] in the lower left corner of x.
What I want is an array similar to mx that keeps the indices of the largest elements, like this:
[[0, 1, 1, 2],
 [0, 2, 3, 2],
 [2, 2, 2, 2]]
where, for example, the 2 in the lower left corner is the index of 8 in the linear representation of [[1, 6], [8, 7]].
I could do it with np.argmax(f[i, :, j, :]), iterating over i and j, but an explicit Python loop is far slower than a vectorized call when there is a lot of data to process. To give you an idea, I'm trying to implement max pooling using only NumPy. Basically, I'm asking whether there is a faster alternative to what I'm using.
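For reference, the loop-based baseline described above might look like the sketch below. The function name block_argmax_loop is hypothetical, and the sketch assumes both dimensions of x are divisible by s.

```python
import numpy as np

def block_argmax_loop(x, s):
    """Loop baseline: flat argmax inside each s-by-s block (hypothetical helper)."""
    h, w = x.shape
    f = x.reshape(h // s, s, w // s, s)
    out = np.empty((h // s, w // s), dtype=np.intp)
    for i in range(h // s):
        for j in range(w // s):
            # f[i, :, j, :] is one (s, s) block; argmax without an axis
            # returns the index into its row-major flattening
            out[i, j] = np.argmax(f[i, :, j, :])
    return out

# The sample array from the question
x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])
result = block_argmax_loop(x, 2)
```

This is only useful as a correctness reference; the two nested Python loops are what make it slow on large arrays.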
Reputation: 221564
Here's one approach -
# Get shape of output array
m,n = np.array(x.shape)//s
# Reshape and permute axes to bring the block as rows
x1 = x.reshape(h//s, s, w//s, s).swapaxes(1,2).reshape(-1,s**2)
# Use argmax along each row and reshape to output shape
out = x1.argmax(1).reshape(m,n)
Sample input, output -
In [362]: x
Out[362]:
array([[9, 4, 8, 9, 5, 7, 3, 3],
       [3, 1, 8, 0, 7, 7, 5, 1],
       [7, 7, 3, 6, 0, 2, 1, 0],
       [7, 3, 9, 8, 1, 6, 7, 7],
       [1, 6, 0, 7, 5, 1, 2, 0],
       [8, 7, 9, 5, 8, 3, 6, 0]])
In [363]: out
Out[363]:
array([[0, 1, 1, 2],
       [0, 2, 3, 2],
       [2, 2, 2, 2]])
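If the flat in-block index needs to be turned into an actual position, one possible follow-up is sketched below. It splits each index into a (row, col) pair inside the block with np.divmod and then offsets by the block's top-left corner to get coordinates in x. The names flat_idx, r, c, rows and cols are illustrative and not part of the approach above; the sketch assumes the shape of x is divisible by s.

```python
import numpy as np

# The sample array from the question
x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])
s = 2
h, w = x.shape
m, n = h // s, w // s

# Same reshape / swapaxes / argmax steps as above
flat_idx = (x.reshape(m, s, n, s)
             .swapaxes(1, 2)
             .reshape(-1, s * s)
             .argmax(1)
             .reshape(m, n))

# Split the flat index into (row, col) within each block
r, c = np.divmod(flat_idx, s)

# Add each block's top-left corner to get coordinates in x
rows = np.arange(m)[:, None] * s + r
cols = np.arange(n)[None, :] * s + c

# Indexing x at those coordinates should reproduce the block maxima
block_max = x[rows, cols]
```

Here block_max should equal the mx array from the question, which is a cheap sanity check on the indices.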
Alternatively, to simplify things, we could use scikit-image, which does the heavy work of reshaping and permuting axes for us -
In [372]: from skimage.util import view_as_blocks as viewB
In [373]: viewB(x, (s,s)).reshape(-1,s**2).argmax(1).reshape(m,n)
Out[373]:
array([[0, 1, 1, 2],
       [0, 2, 3, 2],
       [2, 2, 2, 2]])
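Since the stated goal is max pooling, it may be convenient to get the maxima and their indices from one pass over the blocks. The sketch below wraps the NumPy-only approach in a function and recovers the maxima from the argmax result with np.take_along_axis (available in NumPy 1.15 and later). The function name block_max_and_argmax and the divisibility check are assumptions added for illustration.

```python
import numpy as np

def block_max_and_argmax(x, s):
    """Return (max, flat argmax) for each s-by-s block of a 2D array (hypothetical helper)."""
    h, w = x.shape
    if h % s or w % s:
        raise ValueError("both dimensions of x must be divisible by s")
    m, n = h // s, w // s
    # One row per block, each row holding the block in row-major order
    blocks = x.reshape(m, s, n, s).swapaxes(1, 2).reshape(m * n, s * s)
    idx = blocks.argmax(axis=1)
    # Pick the maximum of each row using the index already computed
    mx = np.take_along_axis(blocks, idx[:, None], axis=1)[:, 0]
    return mx.reshape(m, n), idx.reshape(m, n)

# Example on random data with a fixed seed
rng = np.random.default_rng(0)
x = rng.integers(0, 10, size=(6, 8))
mx, idx = block_max_and_argmax(x, 2)
```

Note that the reshape after swapaxes generally has to copy the data, so this trades some memory for avoiding the Python-level loop.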