Reputation: 151
I have a NumPy array with shape (6, 2, 4):
x = np.array([[[0, 3, 2, 0],
               [1, 3, 1, 1]],
              [[3, 2, 3, 3],
               [0, 3, 2, 0]],
              [[1, 0, 3, 1],
               [3, 2, 3, 3]],
              [[0, 3, 2, 0],
               [1, 3, 2, 2]],
              [[3, 0, 3, 1],
               [1, 0, 1, 1]],
              [[1, 3, 1, 1],
               [3, 1, 3, 3]]])
And I have a choices array like this:
choices = np.array([[1, 1, 1, 1],
                    [0, 1, 1, 0],
                    [1, 1, 1, 1],
                    [1, 0, 0, 0],
                    [1, 0, 1, 1],
                    [0, 0, 0, 1]])
How can I use the choices array to index only the middle dimension (size 2) and get a new NumPy array with shape (6, 4), as efficiently as possible?
The result would be this:
[[1 3 1 1]
[3 3 2 3]
[3 2 3 3]
[1 3 2 0]
[1 0 1 1]
[1 3 1 3]]
I tried x[:, choices, :], but that doesn't return what I want. I also tried x.take(choices, axis=1), with no luck.
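For reference, here is a small sketch of what those two attempts produce. The data below is placeholder data with the same shapes as in the question (not the actual values); both attempts keep every row of x paired with every row of choices, so the result has four dimensions instead of two.

```python
import numpy as np

# Placeholder data: only the shapes match the question, the values do not.
x = np.arange(6 * 2 * 4).reshape(6, 2, 4)
choices = np.zeros((6, 4), dtype=int)

# Both attempts replace axis 1 of x with the full shape of choices,
# so the result is x.shape[:1] + choices.shape + x.shape[2:].
attempt_index = x[:, choices, :]
attempt_take = x.take(choices, axis=1)

print(attempt_index.shape)  # (6, 6, 4, 4)
print(attempt_take.shape)   # (6, 6, 4, 4)
```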
Upvotes: 3
Views: 5380
Reputation: 31
I recently ran into this issue and found @divakar's answer useful, but I still wanted a general function for it (independent of the number of dimensions, etc.). Here it is:
import numpy as np

def take_indices_along_axis(array, choices, choice_axis):
    """
    array is N dim
    choices are integers of N-1 dim
        with values between 0 and array.shape[choice_axis] - 1
    choice_axis is the axis along which you want to take indices
    """
    nb_dims = len(array.shape)
    list_indices = []
    for this_axis, this_axis_size in enumerate(array.shape):
        if this_axis == choice_axis:
            # means this is the axis along which we want to choose
            list_indices.append(choices)
            continue
        # else, we want arange(this_axis_size), reshaped so it broadcasts against choices
        this_indices = np.arange(this_axis_size)
        reshape_target = [1 for _ in range(nb_dims)]
        reshape_target[this_axis] = this_axis_size  # replace the corresponding axis with the right range
        del reshape_target[choice_axis]  # remove the choice_axis
        list_indices.append(
            this_indices.reshape(tuple(reshape_target))
        )
    tuple_indices = tuple(list_indices)
    return array[tuple_indices]
# test it !
array = np.random.random(size=(10, 10, 10, 10))
choices = np.random.randint(10, size=(10, 10, 10))
assert take_indices_along_axis(array, choices, choice_axis=0)[5, 5, 5] == array[choices[5, 5, 5], 5, 5, 5]
assert take_indices_along_axis(array, choices, choice_axis=2)[5, 5, 5] == array[5, 5, choices[5, 5, 5], 5]
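As a possible alternative, the same N-dimensional behaviour can probably be obtained from np.take_along_axis directly by inserting a length-1 axis into choices and squeezing it out afterwards. This is only a sketch; the function name take_along_axis_reduced is made up for illustration, and it assumes NumPy 1.15 or later (where np.take_along_axis is available).

```python
import numpy as np

def take_along_axis_reduced(array, choices, choice_axis):
    """Hypothetical helper: pick one element along choice_axis per position.

    array has N dims; choices has N-1 dims (array.shape without choice_axis).
    """
    # Give choices the same number of dims as array (length 1 on choice_axis).
    expanded = np.expand_dims(choices, axis=choice_axis)
    picked = np.take_along_axis(array, expanded, axis=choice_axis)
    # Drop the length-1 axis again so the result has N-1 dims.
    return np.squeeze(picked, axis=choice_axis)

# Small check against an explicit element lookup.
rng = np.random.default_rng(0)
array = rng.random(size=(3, 4, 5, 6))
choices = rng.integers(5, size=(3, 4, 6))  # indices into axis 2 (size 5)
result = take_along_axis_reduced(array, choices, choice_axis=2)
assert result[1, 2, 3] == array[1, 2, choices[1, 2, 3], 3]
```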
Upvotes: -1
Reputation: 221504
Use np.take_along_axis to index along the second axis:
In [16]: np.take_along_axis(x,choices[:,None],axis=1)[:,0]
Out[16]:
array([[1, 3, 1, 1],
[3, 3, 2, 3],
[3, 2, 3, 3],
[1, 3, 2, 0],
[1, 0, 1, 1],
[1, 3, 1, 3]])
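To make the one-liner easier to follow, here is the same call split into steps as a plain script, using the arrays from the question. The intermediate names idx and picked are only illustrative.

```python
import numpy as np

x = np.array([[[0, 3, 2, 0], [1, 3, 1, 1]],
              [[3, 2, 3, 3], [0, 3, 2, 0]],
              [[1, 0, 3, 1], [3, 2, 3, 3]],
              [[0, 3, 2, 0], [1, 3, 2, 2]],
              [[3, 0, 3, 1], [1, 0, 1, 1]],
              [[1, 3, 1, 1], [3, 1, 3, 3]]])
choices = np.array([[1, 1, 1, 1],
                    [0, 1, 1, 0],
                    [1, 1, 1, 1],
                    [1, 0, 0, 0],
                    [1, 0, 1, 1],
                    [0, 0, 0, 1]])

# take_along_axis needs the index array to have the same number of
# dimensions as x, so add a length-1 axis in the middle: (6, 4) -> (6, 1, 4).
idx = choices[:, None]

# picked[i, 0, j] == x[i, idx[i, 0, j], j]; shape is (6, 1, 4).
picked = np.take_along_axis(x, idx, axis=1)

# Drop the length-1 middle axis to get the (6, 4) result.
result = picked[:, 0]
print(result.shape)  # (6, 4)
```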
Or with explicit integer-array indexing:
In [22]: m,n = choices.shape
In [23]: x[np.arange(m)[:,None],choices,np.arange(n)]
Out[23]:
array([[1, 3, 1, 1],
[3, 3, 2, 3],
[3, 2, 3, 3],
[1, 3, 2, 0],
[1, 0, 1, 1],
[1, 3, 1, 3]])
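The integer-array version works because the three index arrays, with shapes (m, 1), (m, n) and (n,), broadcast to a common (m, n) shape, so that result[i, j] == x[i, choices[i, j], j]. A rough sketch comparing it with an explicit loop, on random placeholder data of the same shapes as in the question:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 4, size=(6, 2, 4))     # placeholder data, shape (6, 2, 4)
choices = rng.integers(0, 2, size=(6, 4))  # placeholder indices into axis 1

m, n = choices.shape
rows = np.arange(m)[:, None]  # shape (m, 1), broadcasts down the rows
cols = np.arange(n)           # shape (n,), broadcasts across the columns
vectorized = x[rows, choices, cols]

# Equivalent explicit loop, for comparison only (much slower on large arrays).
looped = np.empty((m, n), dtype=x.dtype)
for i in range(m):
    for j in range(n):
        looped[i, j] = x[i, choices[i, j], j]

assert np.array_equal(vectorized, looped)
```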
Upvotes: 9