Reputation: 113
I have been trying to perform a simple operation, but I can't seem to find a simple way to do it using Numpy functions without creating unnecessary copies of the array.
Suppose we have the following 3-dimensional array:
In [171]: x = np.arange(24).reshape((4, 3, 2))
In [172]: x
Out[172]:
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]],

       [[12, 13],
        [14, 15],
        [16, 17]],

       [[18, 19],
        [20, 21],
        [22, 23]]])
And the following array:
In [173]: y = np.array([0, 1, 1, 0])
For each row of x (each index along the first dimension), I want to select the values of the last dimension at the index given by the corresponding element of y. In other words, I want:
array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])
The only solution that I have for now is a for loop over the first dimension of x and y, as follows:
z = np.zeros((4, 3), dtype=int)
for i, row in enumerate(x):
    z[i, :] = row[:, y[i]]
Is there a way of avoiding a for loop here, using numpy functions or fancy indexing?
Thanks!
Upvotes: 2
Views: 96
Reputation: 85612
Fancy indexing:
x[np.arange(y.size),:,y]
gives:
array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])
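As a quick sanity check (a sketch, not part of the original answer), the fancy-indexing expression can be compared against the loop from the question. The names `z_loop` and `z_fancy` are made up for illustration. Note that advanced indexing returns a new array rather than a view of x, so this removes the Python loop but not the output copy.

```python
import numpy as np

x = np.arange(24).reshape((4, 3, 2))
y = np.array([0, 1, 1, 0])

# Loop version from the question.
z_loop = np.zeros((4, 3), dtype=int)
for i, row in enumerate(x):
    z_loop[i, :] = row[:, y[i]]

# Fancy-indexing version: z_fancy[i, j] == x[i, j, y[i]].
z_fancy = x[np.arange(y.size), :, y]

# Both approaches produce the same (4, 3) result.
assert np.array_equal(z_loop, z_fancy)

# Advanced indexing allocates a new array; it is not a view of x.
assert not np.shares_memory(z_fancy, x)
```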
Upvotes: 1
Reputation: 18677
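The tricky part is that you don't want every element of y applied to every entry along the 0th dimension (which is what x[:, :, y] would give you); you want the i-th element of y paired with the i-th entry along the 0th dimension. Indexing the 0th dimension with an explicit range does that pairing: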
>>> x[np.arange(x.shape[0]), :, y]
array([[ 0,  2,  4],
       [ 7,  9, 11],
       [13, 15, 17],
       [18, 20, 22]])
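To illustrate the pairing, here is a small sketch contrasting the plain x[:, :, y] (which applies all of y to every row) with the paired version. It also shows np.take_along_axis as one possible alternative, which is not something the answer above proposes. The names `unpaired`, `paired`, `idx` and `z_take` are made up for illustration.

```python
import numpy as np

x = np.arange(24).reshape((4, 3, 2))
y = np.array([0, 1, 1, 0])

# Without the explicit range, every element of y is applied to every row,
# so the result gains an extra axis of length len(y).
unpaired = x[:, :, y]
assert unpaired.shape == (4, 3, 4)

# With the range, the two index arrays are broadcast together and
# consumed pairwise: (0, y[0]), (1, y[1]), ...
paired = x[np.arange(x.shape[0]), :, y]
assert paired.shape == (4, 3)

# Alternative: reshape y to (4, 1, 1) so it broadcasts against x along
# the middle axis, pick along the last axis, then drop that axis.
idx = y[:, None, None]
z_take = np.take_along_axis(x, idx, axis=2)[..., 0]
assert np.array_equal(z_take, paired)
```

Both forms give the same (4, 3) result; which one reads better is largely a matter of taste.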
Upvotes: 5