Faydey

Reputation: 737

numpy array indexing in multiple dimensions

I have a numpy array:

import numpy as np
theta = np.random.normal(size = [1000, 30, 80, 10])

I have an index array:

delta = np.empty([1000, 30, 50], dtype = int)
delta[0,:] = np.arange(50) # stand-in for integer indexing

I want to index the theta array using the delta array. That is, I want something like

theta[0,delta[0]]

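That is, the entry associated with delta[0,20,11] should be theta[0,20,delta[0,20,11]]: the row index (here 20) is carried over from the position in delta, and the value stored in delta selects along the next axis of theta. How would I go about this?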

Ideally, the output would have shape either [30, 50, 10] or [30*50, 10].
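To pin down the intended result, here is a minimal plain-loop sketch of the desired [rows, k, feat] output for one slice along axis 0. The reduced sizes and the random `delta` values are illustrative assumptions, not the original data, and the loop is only meant as a reference for checking faster versions against.

```python
import numpy as np

rng = np.random.default_rng(0)
# Smaller stand-in shapes (assumption): theta[0] is (rows, cols, feat), delta[0] is (rows, k)
rows, cols, k, feat = 3, 8, 5, 2
theta = rng.normal(size=(4, rows, cols, feat))
delta = rng.integers(0, cols, size=(4, rows, k))

# Reference definition: out[r, j] = theta[0, r, delta[0, r, j]]
out = np.empty((rows, k, feat))
for r in range(rows):
    for j in range(k):
        out[r, j] = theta[0, r, delta[0, r, j]]

print(out.shape)  # (3, 5, 2)
```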

I'm now using:

theta_unravel = np.repeat(np.arange(30), 50)
theta[0,theta_unravel,delta[0].ravel()]

which appears to work (it returns a [30*50, 10] result), but I don't know whether this is the ideal solution, or more specifically, the fastest one.

Upvotes: 0

Views: 646

Answers (1)

hpaulj

Reputation: 231665

I think this illustrates what you are trying to do, using a smaller 3d array:

In [758]: x = np.arange(24).reshape(3,4,2)
In [759]: y = np.ones((3,3),int)
In [760]: x[np.arange(3)[:,None],np.arange(3),y]
Out[760]: 
array([[ 1,  3,  5],
       [ 9, 11, 13],
       [17, 19, 21]])

y is (3,3), so the indices for the other dimensions must broadcast to that same shape: np.arange(3)[:,None] is (3,1) and np.arange(3) is (3,).
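Applied to the shapes in the question, the same broadcasting idea can replace the `np.repeat` index: a (30, 1) column of row numbers broadcasts against the (30, 50) `delta[0]`, and the trailing axis of length 10 is carried along unchanged. The sketch below keeps the question's shapes except for a shorter axis 0 (an assumption, to keep memory small) and uses random stand-in indices. It also checks the result against the `np.repeat` version from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=(5, 30, 80, 10))       # axis 0 shortened from 1000 (assumption)
delta = rng.integers(0, 80, size=(5, 30, 50))  # random stand-in indices into axis 2

rows = np.arange(theta.shape[1])[:, None]      # shape (30, 1), broadcasts with (30, 50)
out = theta[0, rows, delta[0]]                 # shape (30, 50, 10)

# The np.repeat version from the question gives the same values, flattened
theta_unravel = np.repeat(np.arange(30), 50)
flat = theta[0, theta_unravel, delta[0].ravel()]   # shape (1500, 10)
assert np.array_equal(out.reshape(-1, 10), flat)

# All of axis 0 at once: the three index arrays broadcast to (n, 30, 50)
n = theta.shape[0]
full = theta[np.arange(n)[:, None, None], rows, delta]  # shape (5, 30, 50, 10)
print(out.shape, full.shape)
```

The broadcast form avoids materialising the repeated (1500,) row index, and `out.reshape(-1, 10)` gives the flattened [30*50, 10] layout when that is preferred.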

In [761]: x
Out[761]: 
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7]],

       [[ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]],

       [[16, 17],
        [18, 19],
        [20, 21],
        [22, 23]]])
In [762]: x[np.arange(3)[:,None],np.arange(3),y]=0
In [763]: x
Out[763]: 
array([[[ 0,  0],
        [ 2,  0],
        [ 4,  0],
        [ 6,  7]],

       [[ 8,  0],
        [10,  0],
        [12,  0],
        [14, 15]],

       [[16,  0],
        [18,  0],
        [20,  0],
        [22, 23]]])
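The assignment above writes into `x` in place. Reading with the same advanced index, by contrast, returns a new array (a copy, not a view), so modifying that result leaves `x` untouched. A small check of both behaviours, rebuilding `x` and `y` as in the session:

```python
import numpy as np

x = np.arange(24).reshape(3, 4, 2)
y = np.ones((3, 3), int)
idx = (np.arange(3)[:, None], np.arange(3), y)

picked = x[idx]          # advanced indexing on read returns a copy
picked[:] = -1           # changes the copy only
print(x[0, 0, 1])        # 1, x is unchanged

x[idx] = 0               # advanced indexing on assignment writes into x
print(x[0, 0, 1])        # 0
```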

Since I used ones to create y (too lazy to do something fancier), that's the same as:

In [765]: x[:,:,1]
Out[765]: 
array([[ 0,  0,  0,  7],
       [ 0,  0,  0, 15],
       [ 0,  0,  0, 23]])
In [766]: x[:,:3,1]
Out[766]: 
array([[0, 0, 0],
       [0, 0, 0],
       [0, 0, 0]])

To be used this way, y must contain valid indices for the last dimension, here 0 or 1, since that dimension has size 2.
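An index at or beyond the size of the indexed axis raises `IndexError` (negative values count from the end, as elsewhere in NumPy). The sketch below wraps the question-shaped case in a range check and, as an alternative the answer does not use, calls `np.take_along_axis` (available since NumPy 1.15), which builds the broadcast index arrays internally. The helper name `pick_along_axis1` is hypothetical, and its rejection of negative indices is a stricter choice than NumPy's own.

```python
import numpy as np

def pick_along_axis1(a, idx):
    """Hypothetical helper: out[r, j, ...] = a[r, idx[r, j], ...].

    a has shape (R, C, F); idx has shape (R, K) with integer entries in [0, C).
    """
    # Stricter than NumPy itself: negative indices (which NumPy would count
    # from the end) are rejected here too.
    if idx.min() < 0 or idx.max() >= a.shape[1]:
        raise ValueError(f"indices must lie in [0, {a.shape[1]})")
    # take_along_axis needs idx to have as many dims as a; the trailing
    # length-1 axis broadcasts against F.
    return np.take_along_axis(a, idx[:, :, None], axis=1)

rng = np.random.default_rng(0)
theta0 = rng.normal(size=(30, 80, 10))         # plays the role of theta[0]
delta0 = rng.integers(0, 80, size=(30, 50))    # plays the role of delta[0]

out = pick_along_axis1(theta0, delta0)
print(out.shape)  # (30, 50, 10)
# Same values as plain broadcast indexing
assert np.array_equal(out, theta0[np.arange(30)[:, None], delta0])

# Plain advanced indexing with an out-of-range value raises IndexError
bad = np.full((30, 50), 80)
try:
    theta0[np.arange(30)[:, None], bad]
except IndexError as err:
    print("IndexError:", err)
```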

Upvotes: 1
