Vadim Kantorov

Reputation: 1144

Inconsistent advanced indexing in NumPy

Why do the following indexing forms produce differently shaped outputs?

import numpy as np

a = np.zeros((5, 5, 5, 5))
print(a[:, :, [1, 2], [3, 4]].shape)
# (5, 5, 2)

print(a[:, :, 1:3, [3, 4]].shape)
# (5, 5, 2, 2)

Almost certain I'm missing something obvious.

Upvotes: 7

Views: 470

Answers (4)

NeoZoom.lua

Reputation: 2921

In the first case, [1, 2] and [3, 4] both have shape (2,). They are broadcast together into a single index dimension of shape (2,), so the result has shape (5, 5, 2); the trailing 2 is the new dimension created by the paired index arrays.

In the second case, the only list, [3, 4], produces one dimension of length 2 on its own, and the slice 1:3 simply shortens its own axis to length 2. Hence the result has shape (5, 5, 2, 2).
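
As a rough check of the shape reasoning above, the sketch below asks NumPy for the broadcast shape of the two index lists and compares it with the result shapes. The names rows, cols, index_shape, first and second are only illustrative and do not appear in the question.

```python
import numpy as np

a = np.zeros((5, 5, 5, 5))
rows = np.array([1, 2])  # illustrative name for the first index list
cols = np.array([3, 4])  # illustrative name for the second index list

# The two index arrays broadcast to one common shape; that single shape
# replaces both indexed axes in the result.
index_shape = np.broadcast(rows, cols).shape

first = a[:, :, rows, cols]    # two index arrays, paired up
second = a[:, :, 1:3, cols]    # slice keeps its own axis, list adds one

print(index_shape)   # (2,)
print(first.shape)   # (5, 5, 2)
print(second.shape)  # (5, 5, 2, 2)
```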

Upvotes: 0

R. S. Nikhil Krishna

Reputation: 4250

The first one,

a[:, :, [1, 2], [3, 4]]

takes indices pairwise and selects the following subarrays:

a[:, :, 1, 3]
a[:, :, 2, 4]

whereas the second one generates all possible combos (and shapes it accordingly), i.e.

a[:, :, 1, 3]
a[:, :, 1, 4]
a[:, :, 2, 3]
a[:, :, 2, 4]

This can be verified with a small experiment. Rather than initializing a as a zero array, use np.arange and reshape it:

a = np.arange(5**4).reshape((5, 5, 5, 5))
print(a[:, :, [1, 2], [3, 4]])

The first few lines of the output are

[[[  8  14]
  [ 33  39]
  [ 58  64]...

and the array a itself is

[[[[  0   1   2   3   4]
   [  5   6   7   8   9]
   [ 10  11  12  13  14]
   [ 15  16  17  18  19]
   [ 20  21  22  23  24]]...

So 8 sits at index (1, 3) of the innermost 2D array (row index 1 is the 2nd row, column index 3 is the 4th column), as expected, and 14 sits at (2, 4). Similarly, 33 is at index (1, 3) and 39 at (2, 4) of the next 2D subarray.
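
The same check can be done programmatically instead of by reading the printed output. This is a minimal sketch; the variable names pairwise, expected_pairwise and combos are made up for illustration.

```python
import numpy as np

a = np.arange(5**4).reshape((5, 5, 5, 5))

# Two lists: indices are taken pairwise, (1, 3) and (2, 4).
pairwise = a[:, :, [1, 2], [3, 4]]
# Build the same thing by hand from the two subarrays, stacked on a new last axis.
expected_pairwise = np.stack([a[:, :, 1, 3], a[:, :, 2, 4]], axis=-1)
print(np.array_equal(pairwise, expected_pairwise))  # True

# Slice plus list: every combination of {1, 2} and {3, 4}.
combos = a[:, :, 1:3, [3, 4]]
# combos[..., i, j] corresponds to a[:, :, 1 + i, 3 + j]
print(np.array_equal(combos[..., 0, 1], a[:, :, 1, 4]))  # True
print(np.array_equal(combos[..., 1, 0], a[:, :, 2, 3]))  # True
```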

Upvotes: 0

Jonas Adler

Reputation: 10799

When you have several lists in advanced indexing, their elements are taken pairwise. By comparison, when you combine a slice with a list, you get every element of the list for every element of the slice.

To see the difference, consider:

>>> print(a[:, :, [1, 2, 3], [3, 4]].shape)
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (2,) 

This is because the first list has length 3 and the second has length 2. These shapes cannot be broadcast together, so you get an error.

By comparison, if you use a slice it works perfectly:

>>> print(a[:, :, 1:4, [3, 4]].shape)
(5, 5, 3, 2)

Explanation

To see why this is the case, we consult the numpy indexing documentation, which states:

When the index consists of as many integer arrays as the array being indexed has dimensions, the indexing is straight forward, but different from slicing.

Advanced indexes always are broadcast and iterated as one:

result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                       ..., ind_N[i_1, ..., i_M]]
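
To make the quoted rule concrete, here is a small sketch on a 2D array. It checks the formula element by element, and then shows one way to get the "all combinations" behaviour with index arrays only, either by giving one index array shape (2, 1) so that broadcasting yields a grid, or with np.ix_. The names x, ind_1, ind_2 mirror the documentation; manual, grid and grid_ix are illustrative.

```python
import numpy as np

x = np.arange(25).reshape(5, 5)
ind_1 = np.array([1, 2])
ind_2 = np.array([3, 4])

# Broadcast and iterated as one: result[i] == x[ind_1[i], ind_2[i]]
result = x[ind_1, ind_2]
manual = np.array([x[ind_1[i], ind_2[i]] for i in range(len(ind_1))])
print(result)  # [ 8 14]

# Shapes (2, 1) and (2,) broadcast to (2, 2), giving every combination.
grid = x[ind_1[:, None], ind_2]
# np.ix_ builds the same open mesh of index arrays.
grid_ix = x[np.ix_(ind_1, ind_2)]
print(grid)
# [[ 8  9]
#  [13 14]]
```

For contiguous indices like these, the grid matches the plain slice x[1:3, 3:5].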

Upvotes: 0

user2357112

Reputation: 281843

[1, 2], [3, 4] doesn't mean "select indices 1 and 2 in one dimension and 3 and 4 in another". It means "select the pairs of indices (1, 3) and (2, 4)".

Your first expression selects all elements at locations of the form a, b, c, d where a and b can be any index and c and d must be either the pair (1, 3) or the pair (2, 4).

Your second expression selects all elements at locations of the form a, b, c, d where a and b can be any index, c must be in the half-open range [1, 3), and d must be either 3 or 4. Unlike the first one, c and d are allowed to be (2, 3) or (1, 4).


Note that using both basic and advanced indexing in the same indexing expression (which mostly means mixing : and advanced indexing) has unintuitive effects on the order of the axes of the result. It's best to avoid mixing them.
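
A short sketch of the axis-order effect mentioned above, using an array whose axes all have different lengths so the reordering is visible in the shape. The names adjacent, separated and scalar_mixed are illustrative only.

```python
import numpy as np

a = np.zeros((3, 4, 5, 6))

# Advanced indices next to each other: the new length-2 axis
# takes their place in the result.
adjacent = a[:, :, [0, 1], [2, 3]]
print(adjacent.shape)      # (3, 4, 2)

# Advanced indices separated by a slice: the new length-2 axis
# is moved to the front of the result.
separated = a[:, [0, 1], :, [2, 3]]
print(separated.shape)     # (2, 3, 5)

# A scalar integer combined with a list behaves the same way
# when a slice sits between them.
scalar_mixed = a[0, :, :, [2, 3]]
print(scalar_mixed.shape)  # (2, 4, 5)
```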

Upvotes: 5
