How to slice an ndarray by multiple dimensions in one line? Check the last line in the following snippet. This seems so basic, yet the result is surprising... but why?
import numpy as np
# create 4 x 3 array
x = np.random.rand(4, 3)
# create row and column filters
rows = np.array([True, False, True, False])
cols = np.array([True, False, True])
print(x[rows, :].shape == (2, 3)) # True ... OK
print(x[:, cols].shape == (4, 2)) # True ... OK
print(x[rows][:, cols].shape == (2, 2)) # True ... OK
print(x[rows, cols].shape == (2, 2)) # False ... WHY???
Since rows and cols are boolean arrays, when you do:
x[rows, cols]
it is like:
x[np.where(rows)[0], np.where(cols)[0]]
which is:
x[[0, 2], [0, 2]]
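which takes only the values at positions (0, 0) and (2, 2). The two index arrays are paired element by element rather than combined into a grid, so the result has shape (2,), not (2, 2). On the other hand, doing: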
x[rows][:, cols]
works like:
x[[0, 2]][:, [0, 2]]
which first selects rows 0 and 2, then selects columns 0 and 2 of that result, returning an array of shape (2, 2) in this example.
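A small sketch to check this, assuming np.arange values in place of np.random.rand so the selected elements are predictable. It also shows np.ix_, which the answer above does not mention, as one way to get the (2, 2) sub-array in a single indexing expression: it turns the two masks into index arrays that broadcast against each other as a grid instead of being paired.

```python
import numpy as np

# Same shapes as in the question; arange (an assumption) makes the values predictable
x = np.arange(12).reshape(4, 3)
rows = np.array([True, False, True, False])
cols = np.array([True, False, True])

# Boolean masks behave like the integer positions of their True entries
r_idx = np.nonzero(rows)[0]  # array([0, 2])
c_idx = np.nonzero(cols)[0]  # array([0, 2])

# Two index arrays are paired element-wise: this picks x[0, 0] and x[2, 2]
paired = x[rows, cols]
print(paired.shape)  # (2,)
print(paired)        # [0 8]

# np.ix_ builds an open mesh from the masks, selecting the full sub-grid at once
block = x[np.ix_(rows, cols)]
print(block.shape)   # (2, 2)
print(block)         # [[0 2]
                     #  [6 8]]
```

Note that the pairing only works here because both masks happen to have two True entries. If, say, cols had three, x[rows, cols] would raise an IndexError because the index arrays could not be broadcast together, while x[np.ix_(rows, cols)] would still return the sub-grid.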