Gulzar

Reputation: 28044

Indexing numpy array with index array of lower dim yields array of higher dim than both

import numpy as np

a = np.zeros((5,4,3))
v = np.ones((5, 4), dtype=int)
data = a[v]
shp = data.shape

This code gives shp == (5, 4, 4, 3).

I don't understand why. How can the output be larger than the input? It makes no sense to me, and I would love an explanation.

Upvotes: 9

Views: 807

Answers (2)

yatu

Reputation: 88305

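This is known as advanced indexing. Advanced indexing lets you select arbitrary elements of the input array using an N-dimensional index array. When a single integer array indexes the first axis, the result has the shape of the index array followed by the remaining axes of the input, here (5, 4) + (4, 3).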
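As a rough illustration of that shape behaviour, the sketch below indexes the same array with index arrays of a few different shapes. The index shapes and the `shapes` dictionary are made up for illustration and are not part of the question.

```python
import numpy as np

a = np.zeros((5, 4, 3))

# Index arrays of various (illustrative) shapes, all filled with 1.
shapes = {}
for idx_shape in [(2,), (5, 4), (2, 3, 6)]:
    idx = np.ones(idx_shape, dtype=int)
    result = a[idx]
    # The first axis of `a` is replaced by the axes of the index array.
    assert result.shape == idx.shape + a.shape[1:]
    shapes[idx_shape] = result.shape

for idx_shape, result_shape in shapes.items():
    print(idx_shape, "->", result_shape)
# (2,) -> (2, 4, 3)
# (5, 4) -> (5, 4, 4, 3)
# (2, 3, 6) -> (2, 3, 6, 4, 3)
```

So the output can have more dimensions than either the input or the index array, because the index array's axes are added in place of the one axis being indexed.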

Let's use another example to make it clearer:

a = np.random.randint(1, 5, (5,4,3))
v = np.ones((5, 4), dtype=int)

Say in this case a is:

array([[[2, 1, 1],
        [3, 4, 4],
        [4, 3, 2],
        [2, 2, 2]],

       [[4, 4, 1],
        [3, 3, 4],
        [3, 4, 2],
        [1, 3, 1]],

       [[3, 1, 3],
        [4, 3, 1],
        [2, 1, 4],
        [1, 2, 2]],
        ...

Here v is an array of ones built with np.ones:

print(v)

[[1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]]

You will simply be indexing a with 1 along the first axis, once for every element of v. Put another way, when you do:

a[1]

[[4, 4, 1],
 [3, 3, 4],
 [3, 4, 2],
 [1, 3, 1]]

You're indexing along the first axis only, since no index is given for the other axes. It is the same as a[1, ...], i.e. taking a full slice along the remaining axes. Hence, by indexing with a 2D array of ones of shape (5, 4), you get the above 2D array a[1], of shape (4, 3), placed at each of the 5*4=20 positions of v, which gives an ndarray of shape (5, 4, 4, 3).

Hence, in this case you'd be getting:

array([[[[4, 4, 1],
         [3, 3, 4],
         [3, 4, 2],
         [1, 3, 1]],

        [[4, 4, 1],
         [3, 3, 4],
         [3, 4, 2],
         [1, 3, 1]],
         ...
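The claim that the result is just a[1] repeated can be checked directly. This is only a sketch: the seeded generator and the names `all_match` and `stacked` are assumptions added here, and `rng.integers` stands in for the `np.random.randint` call above.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility
a = rng.integers(1, 5, size=(5, 4, 3))
v = np.ones((5, 4), dtype=int)
data = a[v]

# Each position (i, j) of the result holds a[v[i, j]], which is a[1] here.
all_match = all(
    np.array_equal(data[i, j], a[v[i, j]])
    for i in range(v.shape[0])
    for j in range(v.shape[1])
)

# The same values can be produced by repeating a[1] over a (5, 4) grid.
stacked = np.broadcast_to(a[1], v.shape + a[1].shape)

print(data.shape)                     # (5, 4, 4, 3)
print(all_match)                      # True
print(np.array_equal(data, stacked))  # True
```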

Upvotes: 6

Adam.Er8

Reputation: 13413

The value of v is:

[[1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]
 [1 1 1 1]]

Every single 1 in v selects a[1], which is not a single number but a complete 4*3 "matrix" taken from a. So every row of v, which holds four 1s, selects a row of four such "matrices".

So you get 5 * 4 copies of a[1], each a 4*3 "matrix", which is where the shape (5, 4, 4, 3) comes from.

If instead of zeros you define a as a = np.arange(5*4*3).reshape((5, 4, 3)), it might be easier to understand, because you get to see which parts of a are being chosen:

import numpy as np

a = np.arange(5*4*3).reshape((5, 4, 3))
v = np.ones((5,4), dtype=int)
data = a[v]
print(data)

(The output is pretty long, so I won't paste it here.)
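Since the full output is long, a shorter way to see what was chosen is to look at one block and at the set of distinct values. The sketch below also tries a mixed index array; `v2` and `picked` are hypothetical names that do not appear in the question.

```python
import numpy as np

a = np.arange(5*4*3).reshape((5, 4, 3))
v = np.ones((5, 4), dtype=int)
data = a[v]

print(data.shape)       # (5, 4, 4, 3)
print(data[0, 0])       # one of the 20 copies of a[1]
print(np.unique(data))  # only the values stored in a[1]: 12 through 23

# v2 is a hypothetical index array with different entries,
# so different blocks of `a` are picked at each position.
v2 = np.array([[0, 4],
               [2, 2]])
picked = a[v2]
print(picked.shape)                        # (2, 2, 4, 3)
print(np.array_equal(picked[0, 1], a[4]))  # True
```

With a non-constant index array it becomes visible that each entry of the index picks its own (4, 3) block from a.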

Upvotes: 0
