Reputation: 809
Suppose I have an array a of shape (2, 2, 2):
a = np.array([[[7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
and an array b that is the max of a along the last axis (row-wise), i.e. b = a.max(-1):
b = np.array([[9, 19],
[24, 18]])
I'd like to obtain the index of each element of b within the flattened a, i.e. a.reshape(-1):
array([ 7, 9, 19, 18, 24, 5, 18, 11])
The result should be an array with the same shape as b, holding the indices of the elements of b in the flattened a:
array([[1, 2],
[4, 6]])
Basically this is the result of MaxPool2d with return_indices=True in PyTorch, but I'm looking for an implementation in NumPy. I've tried where, but it doesn't seem to work. Also, is it possible to combine finding the max and its indices in one go, to be more efficient? Thanks for any help!
Upvotes: 4
Views: 2023
Reputation: 189
I have a solution similar to that of Andras based on np.argmax and np.arange. Instead of "indexing the index" I propose to add a piecewise offset to the result of np.argmax:
import numpy as np
a = np.array([[[7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
off = np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
>>> off
array([[0, 2],
[4, 6]])
This results in:
>>> a.argmax(-1) + off
array([[1, 2],
[4, 6]])
Or as a one-liner:
>>> a.argmax(-1) + np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
array([[1, 2],
[4, 6]])
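The same offset idea should carry over to arrays with any number of leading axes. The following is a minimal sketch, assuming the flattening is in C order (as produced by a.reshape(-1)); the helper name flat_argmax_last_axis is hypothetical and not part of NumPy.

```python
import numpy as np

def flat_argmax_last_axis(a):
    """Flat indices (into a.reshape(-1)) of the maximum along the last axis.

    Hypothetical helper for illustration; assumes C-order flattening.
    """
    a = np.asarray(a)
    n = a.shape[-1]
    # Start position of each length-n "row" in the flattened array
    offsets = np.arange(0, a.size, n).reshape(a.shape[:-1])
    # Position inside the row plus the row's start position
    return a.argmax(-1) + offsets

a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])
idx = flat_argmax_last_axis(a)
```

Indexing the flattened array with the result, a.reshape(-1)[idx], should give back a.max(-1), which is a convenient sanity check.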
Upvotes: 1
Reputation: 762
This is a long one-liner:
new = np.array([np.where(a.reshape(-1)==x)[0][0] for x in a.max(-1).reshape(-1)]).reshape(2,2)
>>> new
array([[1, 2],
       [4, 3]])
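However, the value 18 appears twice in a (at flat indices 3 and 6), and this search returns the first occurrence (3) rather than the one in the row being maximized (6). So which index is the target?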
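To make the ambiguity concrete, here is a small sketch comparing the global search above with a row-local search. The variable names global_first and row_local are only illustrative.

```python
import numpy as np

a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])
flat = a.reshape(-1)

# Global search: first position of each row maximum anywhere in the flat array
global_first = np.array(
    [np.where(flat == x)[0][0] for x in a.max(-1).reshape(-1)]
).reshape(a.shape[:-1])

# Row-local search: argmax inside each row plus that row's start offset
rows = a.reshape(-1, a.shape[-1])
row_local = (rows.argmax(1) + np.arange(rows.shape[0]) * rows.shape[1]).reshape(a.shape[:-1])
```

The two results differ only in the last entry: the global search picks the 18 at flat index 3, which belongs to a different row, while the row-local search picks the 18 at flat index 6.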
Upvotes: 0
Reputation: 35080
The only solution I could think of right now is generating a 2d (or 3d, see below) range that indexes your flat array, and indexing into that with the maximum indices that define b (i.e. a.argmax(-1)):
import numpy as np
a = np.array([[[ 7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
multi_inds = a.argmax(-1)
b_shape = a.shape[:-1]
b_size = np.prod(b_shape)
flat_inds = np.arange(a.size).reshape(b_size, -1)
flat_max_inds = flat_inds[range(b_size), multi_inds.ravel()]
max_inds = flat_max_inds.reshape(b_shape)
I separated the steps with some meaningful variable names, which should hopefully explain what's going on.
multi_inds tells you which "column" to choose in each "row" in a to get the maximum:
>>> multi_inds
array([[1, 0],
[0, 0]])
flat_inds is a 2d array of flat indices, from which one value is to be chosen in each row:
>>> flat_inds
array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
This array is indexed row by row with the column positions in multi_inds. flat_max_inds holds the values you're looking for, but in a flat array:
>>> flat_max_inds
array([1, 2, 4, 6])
So we need to reshape that back to match b.shape:
>>> max_inds
array([[1, 2],
[4, 6]])
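As a quick sanity check on the steps above, the sketch below repeats them and then indexes the flattened a with the result, which should reproduce a.max(-1). The variable recovered is introduced here only for illustration.

```python
import numpy as np

a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])

multi_inds = a.argmax(-1)                      # column of the max in each row
b_shape = a.shape[:-1]
b_size = int(np.prod(b_shape))                 # number of rows after flattening to 2d
flat_inds = np.arange(a.size).reshape(b_size, -1)
max_inds = flat_inds[np.arange(b_size), multi_inds.ravel()].reshape(b_shape)

# Indexing the flat array with max_inds should give back the row maxima
recovered = a.reshape(-1)[max_inds]
```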
A slightly more obscure but also more elegant solution is to use a 3d index array and use broadcasted indexing into it:
import numpy as np
a = np.array([[[ 7, 9],
[19, 18]],
[[24, 5],
[18, 11]]])
multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])
max_inds = np.arange(a.size).reshape(a.shape)[i, j, multi_inds]
This does the same thing without an intermediate flattening into 2d.
The last part is also how you can get b from multi_inds, i.e. without having to call a *max function a second time:
b = a[i, j, multi_inds]
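As an alternative to the explicit index arrays, the same two results can probably be obtained with np.take_along_axis and np.ravel_multi_index. This is a sketch of a different technique, not the method described above, and it assumes NumPy 1.15 or newer (where np.take_along_axis is available).

```python
import numpy as np

a = np.array([[[7, 9],
               [19, 18]],
              [[24, 5],
               [18, 11]]])

multi_inds = a.argmax(-1)

# Gather the maxima using the argmax result; the index array needs the
# same number of dimensions as a, so a trailing axis is added and then dropped
b = np.take_along_axis(a, multi_inds[..., np.newaxis], axis=-1)[..., 0]

# Convert the (i, j, multi_inds) coordinates to flat indices in C order
leading = np.indices(a.shape[:-1])
max_inds = np.ravel_multi_index((*leading, multi_inds), a.shape)
```

Both lines reuse multi_inds, so argmax is the only reduction over a that is computed.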
Upvotes: 1