YotamH

Reputation: 95

Numpy advanced indexing usage

Case 1 (solved): Array A has shape (say) (300,50). Array B is an index array with shape (300,5), such that B[i,j] gives, for row i, the index of another row to "concatenate" next to row i. The end result is an array C with shape (300,5,50), such that C[i,j,:] = A[B[i,j],:]. This can be done by calling A[B,:].

Here is a small example script for case 1:

import numpy as np

## A is the data array
A = np.arange(20).reshape((5,4))
## B indicates, for each row, which rows to pull together
B = np.array([[0,2],[1,2],[2,0],[3,4],[4,1]])
A[B,:]  # The desired result, shape (5,2,4)
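As a quick check of the case 1 semantics, the sketch below compares A[B,:] against an explicit loop that spells out C[i,j,:] = A[B[i,j],:] on the same small arrays. The name C_loop is just an illustrative helper variable, not part of the original script.

```python
import numpy as np

A = np.arange(20).reshape((5, 4))
B = np.array([[0, 2], [1, 2], [2, 0], [3, 4], [4, 1]])

# Fancy indexing on the first axis: result shape is B.shape + A.shape[1:]
C = A[B, :]

# Explicit loop spelling out C[i, j, :] = A[B[i, j], :]
C_loop = np.empty(B.shape + A.shape[1:], dtype=A.dtype)
for i in range(B.shape[0]):
    for j in range(B.shape[1]):
        C_loop[i, j, :] = A[B[i, j], :]

print(C.shape)                    # (5, 2, 4)
print(np.array_equal(C, C_loop))  # True
```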

Case 2 (unsolved): Same problem, only now A is shaped (100,300,50). If B is the index array shaped (100,300,5), the end result would be an array C with shape (100,300,5,50) such that C[i,j,k,:] = A[i,B[i,j,k],:]. A[B,:] doesn't work anymore: B indexes only the first axis of A, so the result has shape B.shape + A.shape[1:] = (100,300,5,300,50) (and any index of 100 or more raises an IndexError, since that axis only has length 100).

How should I approach this with indexing?
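To make the case 2 target concrete, here is a small brute-force reference, assuming shrunken stand-in dimensions (n, m, r, d) = (4, 3, 5, 2) in place of (100, 300, 5, 50). It shows the shape that plain A[B,:] produces and builds the intended C with a loop straight from the definition. The random data and the C_ref name are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, r, d = 4, 3, 5, 2             # small stand-ins for 100, 300, 5, 50
A = rng.random((n, m, d))
B = rng.integers(0, m, size=(n, m, r))

# Naive fancy indexing only indexes axis 0 of A, so the shape blows up
print(A[B, :].shape)                # (4, 3, 5, 3, 2)

# Reference loop for the intended result C[i, j, k, :] = A[i, B[i, j, k], :]
C_ref = np.empty((n, m, r, d))
for i in range(n):
    for j in range(m):
        for k in range(r):
            C_ref[i, j, k, :] = A[i, B[i, j, k], :]
print(C_ref.shape)                  # (4, 3, 5, 2)
```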


Answers (1)

Divakar

Reputation: 221584

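One approach would be to reshape A to 2D, keeping the number of columns (the last axis) intact, then index into the first axis with the flattened B indices, and finally reshape back to the desired shape. For case 2, the block A[i] occupies rows i*300 through i*300+299 of that 2D view, so each B[i,j,k] has to be shifted by i*300 first; otherwise every lookup would land in A[0].

Thus, the implementation for case 2 would be -

offsets = A.shape[1]*np.arange(A.shape[0])[:,None,None]  # first row of each A[i] in the 2D view
A.reshape(-1,A.shape[-1])[(B + offsets).ravel()].reshape(100,300,5,50)

The reshapes are merely views into the (contiguous) arrays, so this should be quite efficient; the only real copies are the output of the indexing step and the temporary B + offsets.

For case 1, A is already 2D, so no offset is needed and the same idea reduces to A.reshape(-1,A.shape[-1])[B.ravel()].reshape(B.shape + A.shape[-1:]), which matches A[B,:]. Here's a sample run for case #1 -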

1) Inputs:

In [667]: A = np.random.rand(3,4)
     ...: B = np.random.randint(0,3,(3,5))
     ...: 

2) Original method:

In [668]: A[B,:]
Out[668]: 
array([[[ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.1 ,  0.91,  0.1 ,  0.98]],

       [[ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02]],

       [[ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.45,  0.16,  0.02,  0.02]]])

3) Proposed method:

In [669]: A.reshape(-1,A.shape[-1])[B.ravel()].reshape(3,5,4)
Out[669]: 
array([[[ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.1 ,  0.91,  0.1 ,  0.98]],

       [[ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.1 ,  0.91,  0.1 ,  0.98],
        [ 0.45,  0.16,  0.02,  0.02]],

       [[ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.48,  0.6 ,  0.96,  0.21],
        [ 0.45,  0.16,  0.02,  0.02],
        [ 0.45,  0.16,  0.02,  0.02]]])
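A sketch of the reshape idea applied to case 2, assuming A is C-contiguous with shape (n, m, d) and B holds indices in [0, m). Because row i*m + j of the 2D view holds A[i, j, :], each B[i,j,k] is shifted by i*m before the lookup. For comparison it also shows a swapped-in alternative that is not part of the answer above: pairing B with a broadcast np.arange index for the first axis. Both helper names (gather_reshape, gather_broadcast) are hypothetical, and the check runs on small stand-in sizes rather than (100,300,5,50).

```python
import numpy as np

def gather_reshape(A, B):
    """C[i, j, k, :] = A[i, B[i, j, k], :] via the 2D-reshape trick."""
    n, m, d = A.shape
    flat = A.reshape(-1, d)                      # row i*m + j holds A[i, j, :]
    offsets = m * np.arange(n)[:, None, None]    # shape (n, 1, 1), broadcasts over B
    return flat[(B + offsets).ravel()].reshape(B.shape + (d,))

def gather_broadcast(A, B):
    """Same result: a broadcast index array for axis 0 paired with B for axis 1."""
    return A[np.arange(A.shape[0])[:, None, None], B]

rng = np.random.default_rng(1)
n, m, r, d = 10, 30, 5, 6                        # small stand-ins for 100, 300, 5, 50
A = rng.random((n, m, d))
B = rng.integers(0, m, size=(n, m, r))

C1 = gather_reshape(A, B)
C2 = gather_broadcast(A, B)

# Brute-force reference straight from the definition
C_ref = np.empty((n, m, r, d))
for i in range(n):
    for j in range(m):
        for k in range(r):
            C_ref[i, j, k, :] = A[i, B[i, j, k], :]

print(C1.shape)                                  # (10, 30, 5, 6)
print(np.array_equal(C1, C_ref), np.array_equal(C2, C_ref))  # True True
print(np.shares_memory(A.reshape(-1, d), A))     # True: the reshape is a view
```

The broadcast variant avoids the explicit offset arithmetic, while the reshape variant makes the flat-row layout explicit; both produce shape B.shape + (d,).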
