RaviTej310

Reputation: 1715

NumPy: select sub-arrays specified by a matrix of indices from a multidimensional array

I have a NumPy array a of shape 5x5x4x5x5 and another array b of shape 5x5 (its entries index the third axis of a, so they lie between 0 and 3). I want to get a[i, j, b[i, j]] for every i and j from 0 to 4, which should give a 5x5x1x5x5 array. Is there a way to do this without two nested for loops?

Upvotes: 3

Views: 103

Answers (2)

Divakar

Reputation: 221574

np.take_along_axis (added in NumPy 1.15) exists exactly for this purpose:

np.take_along_axis(a,b[:,:,None,None,None],axis=2)
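To sanity-check the one-liner, it can be compared with the two-loop version. np.take_along_axis requires the index array to have the same number of dimensions as a, which is why b gets three trailing None axes; those size-1 axes broadcast against the last two axes of a, giving a (5, 5, 1, 5, 5) result. The random test data below is made up for illustration, and np.random.default_rng assumes NumPy 1.17 or later.

```python
import numpy as np

# Illustrative test data (assumed shapes from the question)
rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=(5, 5, 4, 5, 5))
b = rng.integers(0, 4, size=(5, 5))

# b has shape (5, 5); give it shape (5, 5, 1, 1, 1) so its ndim matches a
res = np.take_along_axis(a, b[:, :, None, None, None], axis=2)

# Reference result built with the two loops the question wants to avoid
ref = np.empty((5, 5, 1, 5, 5), dtype=a.dtype)
for i in range(5):
    for j in range(5):
        ref[i, j, 0] = a[i, j, b[i, j]]

print(res.shape)                 # (5, 5, 1, 5, 5)
print(np.array_equal(res, ref))  # True
```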

Upvotes: 0

Autonomous

Reputation: 9075

Let's think of the array a as 100 (= 5 x 5 x 4) matrices of size (5, 5). If you can compute a linear index for each triplet (i, j, b[i, j]), you are done. That's where np.ravel_multi_index comes in.
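To make the linear index concrete: in C (row-major) order, shape (5, 5, 4) advances 20 elements per step along its first axis and 4 per step along its second, so the triplet (i, j, k) maps to i*20 + j*4 + k. The snippet below checks this for one arbitrarily chosen triplet and confirms that row lin of a.reshape(100, 5, 5) is the slab a[i, j, k]; the np.arange test array is an assumption for illustration.

```python
import numpy as np

shape = (5, 5, 4)
i, j, k = 1, 2, 3  # arbitrary example triplet

# Row-major (C-order) linear index: i * (5 * 4) + j * 4 + k
manual = i * shape[1] * shape[2] + j * shape[2] + k
lin = np.ravel_multi_index((i, j, k), shape)
print(manual, lin)  # 31 31

# Row `lin` of the (100, 5, 5) view is the (5, 5) slab a[i, j, k]
a = np.arange(5 * 5 * 4 * 5 * 5).reshape(5, 5, 4, 5, 5)
same = np.array_equal(a.reshape(100, 5, 5)[lin], a[i, j, k])
print(same)  # True
```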

import numpy as np
import itertools

# create some matrices
a = np.random.randint(0, 10, (5, 5, 4, 5, 5))
b = np.random.randint(0, 4, (5, 5))  # values 0..3 index axis 2 of a

# creating all possible triplets - (ind1, ind2, ind3)
inds = list(itertools.product(range(5), range(5)))
(ind1, ind2), ind3 = zip(*inds), b.flatten()

allInds = np.array([ind1, ind2, ind3])
linearInds = np.ravel_multi_index(allInds, (5,5,4))

# reshaping the input array
a_reshaped = np.reshape(a, (100, 5, 5))

# selecting the appropriate indices
res1 = a_reshaped[linearInds, :, :]

# reshaping back into desired shape
res1 = np.reshape(res1, (5, 5, 1, 5, 5))

# verifying with the brute force method
res2 = np.empty((5, 5, 1, 5, 5), dtype=a.dtype)
for i in range(5):
    for j in range(5):
        res2[i, j, 0] = a[i, j, b[i, j], :, :]

print(np.all(res1 == res2))  # should print True
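The itertools.product step can also be vectorized. A sketch, assuming the same shapes as above: np.indices((5, 5)) yields the i and j grids directly, so np.ravel_multi_index can take them together with b and return a (5, 5) array of linear indices. For comparison, the res2 lines use plain NumPy advanced indexing with index arrays that broadcast to (5, 5); that is a different technique from either answer, shown only because it gives the same (5, 5, 1, 5, 5) result. The random test data is illustrative.

```python
import numpy as np

# Illustrative test data (assumed shapes from the question)
rng = np.random.default_rng(1)
a = rng.integers(0, 10, size=(5, 5, 4, 5, 5))
b = rng.integers(0, 4, size=(5, 5))

# Same ravel_multi_index idea, without itertools
ii, jj = np.indices((5, 5))                         # ii[i, j] == i, jj[i, j] == j
lin = np.ravel_multi_index((ii, jj, b), (5, 5, 4))  # shape (5, 5)
res1 = a.reshape(100, 5, 5)[lin][:, :, None]        # (5, 5, 5, 5) -> (5, 5, 1, 5, 5)

# Alternative: advanced indexing; I, J and b broadcast to shape (5, 5)
I = np.arange(5)[:, None]      # shape (5, 1)
J = np.arange(5)[None, :]      # shape (1, 5)
res2 = a[I, J, b][:, :, None]  # a[I, J, b] has shape (5, 5, 5, 5)

# Brute-force reference
ref = np.empty((5, 5, 1, 5, 5), dtype=a.dtype)
for i in range(5):
    for j in range(5):
        ref[i, j, 0] = a[i, j, b[i, j]]

print(np.array_equal(res1, ref), np.array_equal(res2, ref))  # True True
```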

Upvotes: 1
