I have a 3d NumPy array from which I would like to take many slices. These slices will have a length of one or more in the first and second dimensions, whereas the third will be returned in its entirety. The slice should always be 3d.
My attempt at this:
import numpy as np
a = np.zeros((1000, 10, 100))
row_sets = ([19, 20], [21])
col_sets = ([6], [7, 8])
for rows in row_sets:
    for cols in col_sets:
        b = a[[rows], [cols]]
        print(rows, cols, b.shape)
The results:
[19, 20] [6] (1, 2, 100)
[19, 20] [7, 8] (1, 2, 100)
[21] [6] (1, 1, 100)
[21] [7, 8] (1, 2, 100)
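These shapes follow from how NumPy treats integer index arrays: `[rows]` and `[cols]` become arrays of shape `(1, len(rows))` and `(1, len(cols))`, they are broadcast against each other, and the broadcast shape replaces the first two axes of `a`. The sketch below checks that for the same inputs; it assumes NumPy 1.20 or later, where `np.broadcast_shapes` is available.

```python
import numpy as np

a = np.zeros((1000, 10, 100))
row_sets = ([19, 20], [21])
col_sets = ([6], [7, 8])

for rows in row_sets:
    for cols in col_sets:
        # Shapes of the nested-list index arrays, e.g. (1, 2) and (1, 1)
        r_shape = np.shape([rows])
        c_shape = np.shape([cols])
        # The index arrays are broadcast together; the broadcast shape
        # replaces axes 0 and 1, and the untouched axis 2 is appended.
        expected = np.broadcast_shapes(r_shape, c_shape) + a.shape[2:]
        assert a[[rows], [cols]].shape == expected
        print(rows, cols, expected)
```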
If I remove the nested brackets from the index:
b = a[rows, cols]
the second dimension appears to have the same issue, and the result is no longer 3d:
[19, 20] [6] (2, 100)
[19, 20] [7, 8] (2, 100)
[21] [6] (1, 100)
[21] [7, 8] (2, 100)
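Without the extra brackets the two lists are still index arrays, and they are paired element-wise rather than combined as an outer product: `a[[19, 20], [7, 8]]` picks `a[19, 7]` and `a[20, 8]`, not the 2×2 block. A quick check, using an `arange`-filled array (named `demo` here purely for illustration) so the selected elements are distinguishable:

```python
import numpy as np

# Distinct values so we can tell which elements were selected
demo = np.arange(1000 * 10 * 100).reshape(1000, 10, 100)

b = demo[[19, 20], [7, 8]]
# Element-wise pairing: row 19 with col 7, row 20 with col 8
paired = np.stack([demo[19, 7], demo[20, 8]])
print(b.shape)                    # (2, 100)
print(np.array_equal(b, paired))  # True
```

A single-element list such as `[21]` is broadcast against `[7, 8]`, which is why that case also gives `(2, 100)`.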
The result I am looking for would be like this:
[19, 20] [6] (2, 1, 100)
[19, 20] [7, 8] (2, 2, 100)
[21] [6] (1, 1, 100)
[21] [7, 8] (1, 2, 100)
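Using lists of integers as the index triggers advanced (integer-array) indexing: the index arrays are broadcast together and paired element-wise, so the first two dimensions collapse into one. If you want slice-like behaviour, where every selected row is combined with every selected column, you can use np.ix_ to build the index from the lists of ints: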
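for rows in row_sets:
    for cols in col_sets:
        # np.ix_ reshapes rows to (len(rows), 1) and cols to (1, len(cols)),
        # so they broadcast to an outer product instead of being paired
        b = a[np.ix_(rows, cols)]
        print(rows, cols, b.shape)
# Output:
# [19, 20] [6] (2, 1, 100)
# [19, 20] [7, 8] (2, 2, 100)
# [21] [6] (1, 1, 100)
# [21] [7, 8] (1, 2, 100)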
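np.ix_ works by giving each 1-D sequence its own axis so the index arrays broadcast as an outer product. The sketch below inspects the arrays it returns and shows a manual equivalent that turns `rows` into a column vector; the `arange`-filled array is only illustrative.

```python
import numpy as np

a = np.arange(1000 * 10 * 100).reshape(1000, 10, 100)
rows, cols = [19, 20], [7, 8]

r_idx, c_idx = np.ix_(rows, cols)
print(r_idx.shape, c_idx.shape)  # (2, 1) (1, 2)

# Manual equivalent: a (2, 1) row index broadcasts against a (2,) column index
manual = a[np.array(rows)[:, np.newaxis], cols]
print(manual.shape)  # (2, 2, 100)
print(np.array_equal(manual, a[np.ix_(rows, cols)]))  # True
```

As with any advanced indexing, both forms return a copy rather than a view of `a`.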