Bo Ni

Reputation: 595

Select rows from a 3-D ndarray

The situation is as follows: I have a NumPy array of shape (64, 100, 300) and want to reduce it to shape (64, 1, 300) by picking one row per batch, using an index array of shape (64,). How should I do this? Say we have

import numpy as np

a = np.random.randn(64, 100, 300)
indices = np.random.randint(low=0, high=100, size=64)

I currently do

a[:, indices, :]

which does not work. The resulting array has shape (64, 64, 300), because the slice on axis 0 combined with the index array on axis 1 selects all 64 indexed rows from every batch, instead of one row per batch.
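To make the shape problem concrete, here is a small sketch (not part of the original question) showing that `a[:, indices, :]` pairs every batch with every index, and that the wanted row for batch `i` sits at position `[i, i]` of that oversized result. Recovering it with `np.diagonal` works but builds a 64x larger intermediate, which is why a direct indexing expression is preferable.

```python
import numpy as np

a = np.random.randn(64, 100, 300)
indices = np.random.randint(low=0, high=100, size=64)

# A slice (:) on axis 0 plus an index array on axis 1 behaves like an
# outer product: b[i, j] == a[i, indices[j]] for every i and every j.
b = a[:, indices, :]
print(b.shape)  # (64, 64, 300)

# The row wanted for batch i is b[i, i]. np.diagonal over axes 0 and 1
# moves the diagonal axis to the end, giving (300, 64), hence the .T.
wanted = np.diagonal(b, axis1=0, axis2=1).T
print(wanted.shape)  # (64, 300)
```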

Upvotes: 2

Views: 79

Answers (1)

Crazy Coder

Reputation: 414

As the comment above suggests, index axis 0 with `np.arange` so that each batch is paired with its own index:

a[np.arange(indices.size),indices,None]
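A quick check of why this works, as a sketch using the question's random setup: the two index arrays sit on adjacent axes, so NumPy broadcasts them together and pairs them element-wise (`a[k, indices[k]]`), and the trailing `None` inserts the length-1 axis. The loop-based `expected` array is only there for comparison.

```python
import numpy as np

a = np.random.randn(64, 100, 300)
indices = np.random.randint(low=0, high=100, size=64)

# rows and indices both have shape (64,) and are paired element-wise,
# so position k selects a[k, indices[k]], a vector of length 300.
rows = np.arange(indices.size)
out = a[rows, indices, None]  # None adds the middle length-1 axis
print(out.shape)  # (64, 1, 300)

# Cross-check against an explicit loop over the batch axis.
expected = np.stack([a[k, indices[k]][None, :] for k in range(indices.size)])
print(np.array_equal(out, expected))  # True
```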

Or, equivalently but more readably, select the rows first and add the axis afterwards:

a[np.arange(indices.size),indices][:,None,:]
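Another option, not mentioned in the answer, is `np.take_along_axis`, which keeps the length-1 axis without a separate `None`. This is a sketch assuming NumPy 1.15 or later (where `take_along_axis` was introduced): the index array must have the same number of dimensions as `a`, so it is reshaped to (64, 1, 1), and its last axis broadcasts against the 300 columns.

```python
import numpy as np

a = np.random.randn(64, 100, 300)
indices = np.random.randint(low=0, high=100, size=64)

# indices[:, None, None] has shape (64, 1, 1): one row index per batch.
# Axis 0 must match (64), axis 1 is the selection axis (length 1 kept),
# and axis 2 broadcasts to 300.
out = np.take_along_axis(a, indices[:, None, None], axis=1)
print(out.shape)  # (64, 1, 300)

# Same result as the answer's indexing expression.
print(np.array_equal(out, a[np.arange(indices.size), indices][:, None, :]))  # True
```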

Upvotes: 3
