Reputation: 28044
import numpy as np
a = np.zeros((5, 4, 3))
v = np.ones((5, 4), dtype=int)
data = a[v]
shp = data.shape
This code gives shp==(5,4,4,3)
I don't understand why. How can the output be larger than the input array? It makes no sense to me, and I would love an explanation.
Reputation: 88305
This is known as advanced indexing. Advanced indexing allows you to select arbitrary elements in the input array based on an N-dimensional index.
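A minimal sketch of the idea, using an illustrative array and a hypothetical index array `idx` (neither is from the question): when a single integer array indexes the first axis, each entry of that array picks out one whole sub-array, so the result has shape `idx.shape + a.shape[1:]`.

```python
import numpy as np

# Small 3D array whose values make each (4, 3) block easy to recognise
a = np.arange(5 * 4 * 3).reshape((5, 4, 3))

# Hypothetical 2D integer index array: each entry selects one block along axis 0
idx = np.array([[0, 2],
                [4, 4]])

picked = a[idx]
print(picked.shape)  # (2, 2, 4, 3)

# Entry (i, j) of the result is the block a[idx[i, j]]
print(np.array_equal(picked[0, 1], a[2]))  # True
print(np.array_equal(picked[1, 0], a[4]))  # True
```

Repeating an index (the two 4s above) simply copies the same block twice, which is exactly what happens with an index array full of 1s.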
Let's use another example to make it clearer:
import numpy as np
# random integers from 1 to 4 (the upper bound 5 is exclusive)
a = np.random.randint(1, 5, (5, 4, 3))
v = np.ones((5, 4), dtype=int)
Say in this case a is:
array([[[2, 1, 1],
        [3, 4, 4],
        [4, 3, 2],
        [2, 2, 2]],

       [[4, 4, 1],
        [3, 3, 4],
        [3, 4, 2],
        [1, 3, 1]],

       [[3, 1, 3],
        [4, 3, 1],
        [2, 1, 4],
        [1, 2, 2]],
...
By indexing with an array of np.ones:
print(v)
[[1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]]
You will simply be indexing a with 1 along the first axis, once for every element of v, and the selected sub-arrays are laid out in the shape of v. Putting it another way, when you do:
a[1]
[[4, 4, 1],
 [3, 3, 4],
 [3, 4, 2],
 [1, 3, 1]]
You're indexing along the first axis only, since no index is specified for the remaining axes. It is the same as doing a[1, ...], i.e. taking a full slice along the remaining axes. Hence, by indexing with a 2D array of ones of shape (5, 4), you get the above 2D array stacked 5*4 = 20 times in a (5, 4) layout, resulting in an ndarray of shape (5, 4, 4, 3). In other words, a[1], of shape (4, 3), is repeated 20 times, and the result's shape is v.shape + a[1].shape = (5, 4) + (4, 3).
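The shape rule and the `a[1]` / `a[1, ...]` equivalence described above can be checked directly; this sketch reuses the random `a` and the all-ones `v` from the example.

```python
import numpy as np

a = np.random.randint(1, 5, (5, 4, 3))
v = np.ones((5, 4), dtype=int)

data = a[v]

# The index array's shape comes first, followed by the remaining axes of a
print(data.shape)             # (5, 4, 4, 3)
print(v.shape + a.shape[1:])  # (5, 4, 4, 3)

# a[1] and a[1, ...] select the same (4, 3) block
print(np.array_equal(a[1], a[1, ...]))  # True

# Every one of the 5*4 = 20 positions holds a copy of that block
print(all(np.array_equal(data[i, j], a[1])
          for i in range(5) for j in range(4)))  # True
```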
Hence, in this case you'd be getting:
array([[[[4, 4, 1],
         [3, 3, 4],
         [3, 4, 2],
         [1, 3, 1]],

        [[4, 4, 1],
         [3, 3, 4],
         [3, 4, 2],
         [1, 3, 1]],
...
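As a sanity check (not something the explanation depends on), the same result can be built two other ways: with `np.take` along axis 0, which performs the same selection, and, only because every index here is 1, by broadcasting `a[1]` to the full shape. Note that `np.broadcast_to` returns a read-only view, whereas `a[v]` returns a new array.

```python
import numpy as np

a = np.random.randint(1, 5, (5, 4, 3))
v = np.ones((5, 4), dtype=int)

data = a[v]

# np.take along axis 0 selects the same blocks as a[v]
taken = np.take(a, v, axis=0)
print(np.array_equal(data, taken))  # True

# Every index is 1, so the result is a[1] repeated over a (5, 4) grid
repeated = np.broadcast_to(a[1], v.shape + a[1].shape)
print(np.array_equal(data, repeated))  # True
```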
Reputation: 13413
The value of v is:
[[1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]]
Every single 1 indexes a complete "row" in a, i.e. a[1], and that "row" is itself a 4*3 "matrix". So every "row" in v indexes a "row" of "matrices" in a.
(Does this make any sense to you?)
So you get 5 * 4 1s, each of which selects a 4*3 "matrix", which is why the result has shape (5, 4, 4, 3).
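The same "one matrix per 1" picture can be spelled out as an explicit loop. This is only an illustrative sketch (the name `manual` is hypothetical, and a Python loop is far slower than advanced indexing); np.arange is used so the blocks hold distinct values.

```python
import numpy as np

a = np.arange(5 * 4 * 3).reshape((5, 4, 3))
v = np.ones((5, 4), dtype=int)

# Build the result by hand: one (4, 3) "matrix" per entry of v
manual = np.empty(v.shape + a.shape[1:], dtype=a.dtype)
for i in range(v.shape[0]):
    for j in range(v.shape[1]):
        manual[i, j] = a[v[i, j]]  # each 1 selects the whole block a[1]

print(manual.shape)                  # (5, 4, 4, 3)
print(np.array_equal(manual, a[v]))  # True
```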
If, instead of zeros, you define a as a = np.arange(5*4*3).reshape((5, 4, 3)), it might be easier to understand, because you get to see which parts of a are being chosen:
import numpy as np
a = np.arange(5*4*3).reshape((5, 4, 3))
v = np.ones((5,4), dtype=int)
data = a[v]
print(data)
(The output is pretty long, so I don't want to paste it here.)
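Instead of printing the whole result, a sketch like this inspects one block and the set of values that appear, which already shows that only a[1] (the values 12 to 23) is being chosen:

```python
import numpy as np

a = np.arange(5 * 4 * 3).reshape((5, 4, 3))
v = np.ones((5, 4), dtype=int)
data = a[v]

# Each of the 20 blocks is a[1], which holds the values 12..23
print(data[0, 0])
# [[12 13 14]
#  [15 16 17]
#  [18 19 20]
#  [21 22 23]]

# Only values from a[1] appear anywhere in the result
print(np.unique(data))  # [12 13 14 15 16 17 18 19 20 21 22 23]
```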