triphook

Reputation: 3097

Preserving dimensions with variable NumPy slice

I have a 3-D NumPy array from which I would like to take many slices. Each slice selects one or more indices along the first and second dimensions, while the third dimension is returned in its entirety. The slice should always remain 3-D.

My attempt at this:

import numpy as np

a = np.zeros((1000, 10, 100))
row_sets = ([19, 20], [21])
col_sets = ([6], [7, 8])

for rows in row_sets:
    for cols in col_sets:
        b = a[[rows], [cols]]
        print(rows, cols, b.shape)

The results:

[19, 20] [6] (1, 2, 100)
[19, 20] [7, 8] (1, 2, 100)
[21] [6] (1, 1, 100)
[21] [7, 8] (1, 2, 100)
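A small sketch of where the leading 1 comes from, assuming NumPy 1.20 or later (for np.broadcast_shapes). Wrapping each list in another list makes the index arrays 2-D, with shapes (1, len(rows)) and (1, len(cols)). Advanced indexing broadcasts those index arrays together, and the untouched last axis is appended to the broadcast shape. The names rows_idx, cols_idx and index_shape are illustrative only.

```python
import numpy as np

a = np.zeros((1000, 10, 100))

# Wrapping a list in another list turns it into a 2-D index array.
rows_idx = np.array([[19, 20]])   # shape (1, 2)
cols_idx = np.array([[6]])        # shape (1, 1)

# Advanced indexing broadcasts the index arrays against each other...
index_shape = np.broadcast_shapes(rows_idx.shape, cols_idx.shape)
print(index_shape)                # (1, 2)

# ...and the unindexed last axis is appended to that broadcast shape.
b = a[rows_idx, cols_idx]
print(b.shape)                    # (1, 2, 100)
```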

If I remove the nested brackets from the slice:

b = a[rows, cols]

I get what appears to be the same problem, now in the second dimension, and the result is no longer 3-D:

[19, 20] [6] (2, 100)
[19, 20] [7, 8] (2, 100)
[21] [6] (1, 100)
[21] [7, 8] (2, 100)
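To see why a[rows, cols] drops to 2-D, here is a rough sketch (the array is filled with np.arange purely so that selected elements are distinguishable; that choice is an assumption for illustration). Two 1-D index lists are paired element by element rather than combined as an outer product, and a single-element list such as [6] is broadcast to match the other list's length.

```python
import numpy as np

# Distinct values make it visible which elements are selected.
a = np.arange(1000 * 10 * 100).reshape(1000, 10, 100)

rows, cols = [19, 20], [7, 8]
b = a[rows, cols]                      # shape (2, 100)

# Output row i is a[rows[i], cols[i]]; a[19, 8] and a[20, 7] are never selected.
paired = np.stack([a[19, 7], a[20, 8]])
print(np.array_equal(b, paired))       # True

# A length-1 list is broadcast: [6] pairs with both 19 and 20.
b6 = a[[19, 20], [6]]
print(np.array_equal(b6, np.stack([a[19, 6], a[20, 6]])))  # True
print(b.shape, b6.shape)               # (2, 100) (2, 100)
```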

The result I am looking for would be like this:

[19, 20] [6] (2, 1, 100)
[19, 20] [7, 8] (2, 2, 100)
[21] [6] (1, 1, 100)
[21] [7, 8] (1, 2, 100)
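One possible way to get these shapes, offered here as a sketch rather than as something from the question: index one axis at a time. Each step uses a single index list, so that axis keeps its full length instead of being paired with the other list. Note that a[rows] creates an intermediate copy before the column selection is applied.

```python
import numpy as np

a = np.zeros((1000, 10, 100))
row_sets = ([19, 20], [21])
col_sets = ([6], [7, 8])

shapes = []
for rows in row_sets:
    for cols in col_sets:
        # First select rows (copy of shape (len(rows), 10, 100)),
        # then select columns along axis 1 of that intermediate result.
        b = a[rows][:, cols]
        shapes.append(b.shape)
        print(rows, cols, b.shape)
```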

Upvotes: 1

Views: 175

Answers (1)

akuiper

Reputation: 215117

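You are triggering advanced indexing by using lists of integers as the index. With advanced indexing, the index arrays are broadcast against each other and the result takes their broadcast shape, so you get paired (row, col) elements rather than a rectangular block, and a dimension can be lost. If you still want slice-like behaviour, you can use np.ix_ to build an open-mesh index from the lists of ints: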

# a, row_sets and col_sets as defined in the question
for rows in row_sets:
    for cols in col_sets:
        b = a[np.ix_(rows, cols)]
        print(rows, cols, b.shape)

#[19, 20] [6] (2, 1, 100)
#[19, 20] [7, 8] (2, 2, 100)
#[21] [6] (1, 1, 100)
#[21] [7, 8] (1, 2, 100)
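As a sketch of what np.ix_ does here: it reshapes each 1-D list so that it varies along its own axis, giving shapes (2, 1) and (1, 2) for two two-element lists. Broadcasting those produces every (row, col) combination. The hand-built equivalent below, using np.newaxis, is shown only for comparison; the names r, c, manual and via_ix are illustrative, and the arange fill is just there to make element positions checkable.

```python
import numpy as np

a = np.arange(1000 * 10 * 100).reshape(1000, 10, 100)
rows, cols = [19, 20], [7, 8]

# np.ix_ turns each 1-D list into an array that varies along its own axis.
ix_rows, ix_cols = np.ix_(rows, cols)
print(ix_rows.shape, ix_cols.shape)    # (2, 1) (1, 2)

# The same open mesh built by hand: a column vector of rows, a row vector of cols.
r = np.asarray(rows)[:, np.newaxis]
c = np.asarray(cols)[np.newaxis, :]

via_ix = a[ix_rows, ix_cols]
manual = a[r, c]
print(via_ix.shape)                    # (2, 2, 100)
print(np.array_equal(via_ix, manual))  # True

# Element [i, j] comes from a[rows[i], cols[j]].
print(np.array_equal(via_ix[1, 0], a[20, 7]))  # True
```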

Upvotes: 2
