haidahaida

Reputation: 229

Get index of largest element for each submatrix in a Numpy 2D array

I have a 2D NumPy ndarray, x, that I need to split into square subregions of size s. For each subregion, I want the greatest element (which I can already compute) and its position within that subregion (which I can't figure out).

Here is a minimal example:

>>> x = np.random.randint(0, 10, (6,8))
>>> x
array([[9, 4, 8, 9, 5, 7, 3, 3],
       [3, 1, 8, 0, 7, 7, 5, 1],
       [7, 7, 3, 6, 0, 2, 1, 0],
       [7, 3, 9, 8, 1, 6, 7, 7],
       [1, 6, 0, 7, 5, 1, 2, 0],
       [8, 7, 9, 5, 8, 3, 6, 0]])
>>> h, w = x.shape
>>> s = 2
>>> f = x.reshape(h//s, s, w//s, s)
>>> mx = np.max(f, axis=(1, 3))
>>> mx
array([[9, 9, 7, 5],
       [7, 9, 6, 7],
       [8, 9, 8, 6]])

For example, the 8 in the lower left corner of mx is the greatest element from subregion [[1,6], [8, 7]] in the lower left corner of x.

What I want is an array with the same shape as mx that holds the indices of the largest elements, like this:

[[0, 1, 1, 2],
 [0, 2, 3, 2],
 [2, 2, 2, 2]]

where, for example, the 2 in the lower left corner is the index of 8 in the linear representation of [[1, 6], [8, 7]].

I could call np.argmax(f[i, :, j, :]) while iterating over i and j, but a Python-level loop is far slower than a vectorized operation on large inputs. For context, I'm trying to implement max pooling using only NumPy. Is there a faster alternative to the loop?
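For reference, the loop described above could be sketched as follows. The helper name block_argmax_loop is made up for illustration, and the sketch assumes both dimensions of x are divisible by s. It is only meant as a baseline to check faster versions against.

```python
import numpy as np

def block_argmax_loop(x, s):
    """Loop baseline: flat argmax inside each s x s block (hypothetical helper)."""
    h, w = x.shape
    f = x.reshape(h // s, s, w // s, s)
    out = np.empty((h // s, w // s), dtype=np.intp)
    for i in range(h // s):
        for j in range(w // s):
            # f[i, :, j, :] is the (s, s) block at block-row i, block-column j;
            # np.argmax flattens it in row-major order before searching
            out[i, j] = np.argmax(f[i, :, j, :])
    return out

# The sample array from the question
x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])

loop_idx = block_argmax_loop(x, 2)
print(loop_idx)
```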

Upvotes: 2

Views: 1047

Answers (1)

Divakar

Reputation: 221564

Here's one approach -

# x and s are as defined in the question
import numpy as np

h, w = x.shape

# Shape of the output array (number of blocks along each axis)
m, n = h // s, w // s

# Reshape and permute axes so that each s x s block becomes one row
x1 = x.reshape(m, s, n, s).swapaxes(1, 2).reshape(-1, s**2)

# Take the argmax along each row, then reshape to the output shape
out = x1.argmax(1).reshape(m, n)

Sample input, output -

In [362]: x
Out[362]: 
array([[9, 4, 8, 9, 5, 7, 3, 3],
       [3, 1, 8, 0, 7, 7, 5, 1],
       [7, 7, 3, 6, 0, 2, 1, 0],
       [7, 3, 9, 8, 1, 6, 7, 7],
       [1, 6, 0, 7, 5, 1, 2, 0],
       [8, 7, 9, 5, 8, 3, 6, 0]])

In [363]: out
Out[363]: 
array([[0, 1, 1, 2],
       [0, 2, 3, 2],
       [2, 2, 2, 2]])
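A self-contained version of the steps above might look like the sketch below. The function name block_argmax and the divisibility check are additions for illustration, not part of the original snippet. Since max pooling usually needs a (row, column) position rather than a flat index, the sketch also converts the flat indices with np.unravel_index.

```python
import numpy as np

def block_argmax(x, s):
    """Flat index of the maximum inside each s x s block (hypothetical helper)."""
    h, w = x.shape
    if h % s or w % s:
        raise ValueError("both dimensions of x must be divisible by s")
    m, n = h // s, w // s
    # (m, s, n, s) -> (m, n, s, s) -> one row of s*s values per block
    blocks = x.reshape(m, s, n, s).swapaxes(1, 2).reshape(-1, s * s)
    return blocks.argmax(axis=1).reshape(m, n)

# The sample array from the question
x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])
s = 2

idx = block_argmax(x, s)

# Convert each flat index into (row, column) coordinates inside its block
rows, cols = np.unravel_index(idx, (s, s))
print(idx)
print(rows)
print(cols)
```

For the lower-left block [[1, 6], [8, 7]], the flat index 2 corresponds to row 1, column 0 of the block.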

Alternatively, to simplify things, we could use scikit-image, which does the heavy lifting of reshaping and permuting axes for us (m and n are the output dimensions computed above) -

In [372]: from skimage.util import view_as_blocks as viewB

In [373]: viewB(x, (s,s)).reshape(-1,s**2).argmax(1).reshape(m,n)
Out[373]: 
array([[0, 1, 1, 2],
       [0, 2, 3, 2],
       [2, 2, 2, 2]])
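If adding scikit-image as a dependency is not an option, a similar block view can probably be built with NumPy alone. The sketch below is an assumption on my part rather than something from the approach above: it uses sliding_window_view, which requires NumPy 1.20 or newer, and keeps only every s-th window so that the windows do not overlap.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# The sample array from the question
x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])
s = 2
m, n = x.shape[0] // s, x.shape[1] // s

# All s x s windows, then every s-th one along each axis: shape (m, n, s, s)
blocks = sliding_window_view(x, (s, s))[::s, ::s]

# Flatten each block (this reshape copies the data) and take the argmax
out_swv = blocks.reshape(m, n, s * s).argmax(axis=-1)
print(out_swv)
```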

Upvotes: 2
