Slicing multiple dimensions of an ndarray

How do I slice an ndarray along multiple dimensions in one line? Look at the last line of the following snippet: it seems so basic, yet the result is surprising. Why?

import numpy as np

# create 4 x 3 array
x = np.random.rand(4, 3)

# create row and column filters
rows = np.array([True, False, True, False])
cols = np.array([True, False, True])

print(x[rows, :].shape == (2, 3))  # True ... OK
print(x[:, cols].shape == (4, 2))  # True ... OK
print(x[rows][:, cols].shape == (2, 2))  # True ... OK
print(x[rows, cols].shape == (2, 2))  # False ... WHY???
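Printing the shape instead of comparing it shows what the last line actually returns: a 1-D array with just two elements, not a 2 x 2 block. This sketch reuses the question's masks; the seeded generator is an addition made only so the run is reproducible.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily, for reproducibility only
x = rng.random((4, 3))          # 4 x 3 array, as in the question

rows = np.array([True, False, True, False])
cols = np.array([True, False, True])

picked = x[rows, cols]
print(picked.shape)  # (2,)

# The two elements are x[0, 0] and x[2, 2], not every row/column combination.
print(np.array_equal(picked, np.array([x[0, 0], x[2, 2]])))  # True
```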


Saullo G. P. Castro

Since rows and cols are boolean arrays, when you do:

x[rows, cols]

it is equivalent to:

x[np.where(rows)[0], np.where(cols)[0]]

which is:

x[[0, 2], [0, 2]]

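This pairs the two index arrays element by element: NumPy broadcasts multiple advanced (integer or boolean) indices against each other, so the result contains only the values at positions (0, 0) and (2, 2) and has shape (2,), not (2, 2). (If the masks had different numbers of True entries, the index arrays could not be broadcast together and NumPy would raise an IndexError.) On the other hand, doing: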

x[rows][:, cols]

works like:

x[[0, 2]][:, [0, 2]]

which first selects rows 0 and 2 and then columns 0 and 2 of that intermediate result, returning an array of shape (2, 2) in this example.
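To get the (2, 2) block in a single indexing step, one option is `np.ix_`, which builds an open mesh from the masks so that the broadcast indices cover every (row, column) combination instead of pairing them. The sketch below uses a small deterministic `np.arange` array in place of the question's random one so the printed values are predictable.

```python
import numpy as np

x = np.arange(12).reshape(4, 3)  # deterministic 4 x 3 array: rows [0 1 2], [3 4 5], ...

rows = np.array([True, False, True, False])
cols = np.array([True, False, True])

# np.ix_ turns the masks into a (2, 1) row index and a (1, 2) column index.
# Broadcast together, they select the outer product of rows {0, 2} and columns {0, 2}.
sub = x[np.ix_(rows, cols)]

print(sub.shape)  # (2, 2)
print(sub)
# [[0 2]
#  [6 8]]

# Same values as the two-step version from the question.
print(np.array_equal(sub, x[rows][:, cols]))  # True
```

`np.ix_` is documented to treat boolean sequences as masks (equivalent to passing `np.nonzero(mask)`), which is what lets it stand in for `x[rows][:, cols]`. The chained form gives the same values, but it builds an intermediate copy of the selected rows first.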
