FraserOfSmeg

Reputation: 1148

Advanced indexing is returning an array with the wrong shape

I've been using this reference to understand advanced indexing. One specific example reads as follows:

Example Suppose x.shape is (10,20,30) and ind is a (2,3,4)-shaped indexing intp array, then result = x[...,ind,:] has shape (10,2,3,4,30) because the (20,)-shaped subspace has been replaced with a (2,3,4)-shaped broadcasted indexing subspace. If we let i, j, k loop over the (2,3,4)-shaped subspace then result[...,i,j,k,:] = x[...,ind[i,j,k],:]. This example produces the same result as x.take(ind, axis=-2).
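The quoted example can be checked directly. The following is a minimal sketch; the array contents and the random generator seed are arbitrary choices made for illustration, and only the shapes come from the quoted text.

```python
import numpy as np

# x has the shape from the quoted example; the values are arbitrary.
x = np.arange(10 * 20 * 30).reshape(10, 20, 30)

# ind is a (2, 3, 4) integer array whose entries are valid indices
# into the axis of length 20 (seed chosen arbitrarily).
rng = np.random.default_rng(0)
ind = rng.integers(0, 20, size=(2, 3, 4))

# The single advanced index replaces the (20,) axis with ind's shape.
result = x[..., ind, :]
print(result.shape)

# Same result as take along the second-to-last axis.
same_as_take = np.array_equal(result, x.take(ind, axis=-2))
print(same_as_take)
```

Only one index array is involved here, so no broadcasting between index arrays takes place.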

I've been trying to understand this for a while, and to help me I've got a little script that produces some arrays. I have:

Indexing arrays
i => 12 x 25
j => 12 x 25
k => 12 x 1

Input array
x => 2 x 3 x 4 x 4

Output Array
Cols => 2 x 12 x 25

The code I use to make Cols is as follows:

cols = x[:, k, i, j]

From my understanding of the example, cols should actually have shape (2 x 12 x 1 x 12 x 25 x 12 x 25). I arrived at this as follows:

Its original dimensions are 2 x 3 x 4 x 4

The 2 is unchanged but all other dimensions are altered

The 3 is replaced with k, a 12 x 1 array

The first 4 is replaced by i, a 12 x 25 array

The second 4 is replaced by j, also a 12 x 25 array

Clearly I'm misunderstanding something here. Where am I going wrong?

Upvotes: 1

Views: 193

Answers (1)

Daniel F

Reputation: 14399

This does what you want:

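import numpy as np

# index arrays and input array with the shapes from the question
i = np.random.randint(0, 4, (12, 25))
j = np.random.randint(0, 4, (12, 25))
k = np.random.randint(0, 3, (12, 1))
x = np.random.randint(1, 11, (2, 3, 4, 4))

# index one axis at a time so the index arrays are never broadcast together
x1 = x[:,k,:,:][:,:,:,i,:][:,:,:,:,:,j]
x1.shape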

(2, 12, 1, 12, 25, 12, 25)

Why doesn't the original method work that way? I think it is because advanced indexing is greedy: when several index arrays appear in one indexing expression, it treats them as indexing multiple dimensions simultaneously. For instance, your original shape:

x.shape
(2,3,4,4) 

Could be interpreted many ways. What you want is for each axis to be independent, but it is just as valid to interpret it as 6 (4,4) matrices or 2 (3,4,4) tensors. So when indexing by [...,i,j], i could be taken to index the third axis and j the fourth independently, or i and j could be taken together as a pair indexing the last two axes. NumPy takes the second interpretation:

x[...,i,j].shape
(2,3,12,25)

You can also interpret x as 8 (3,4) matrices, which is what happens when you do:

x[:,k,i,:].shape
(2,12,25,4)

Notice that it has also broadcast your (12,1) k array to (12,25) in order to match i for indexing. You can confirm that broadcasting is happening by using .squeeze() on k, which leaves a (12,) array that can no longer be broadcast against (12,25):

x[:,k.squeeze(),i,:]
Traceback (most recent call last):
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (12,) (12,25)
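The same broadcasting check can be made without indexing x at all. This is a small sketch, assuming placeholder index arrays filled with zeros (only their shapes matter); it uses np.broadcast, which applies the ordinary broadcasting rules.

```python
import numpy as np

# Placeholder index arrays; only the shapes are relevant here.
k = np.zeros((12, 1), dtype=np.intp)
i = np.zeros((12, 25), dtype=np.intp)

# (12, 1) against (12, 25): the length-1 axis stretches to 25.
broadcast_shape = np.broadcast(k, i).shape
print(broadcast_shape)

# (12,) against (12, 25): trailing axes are 12 and 25, which do not match.
try:
    np.broadcast(k.squeeze(), i)
    squeezed_broadcasts = True
except ValueError:
    squeezed_broadcasts = False
print(squeezed_broadcasts)
```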

With x[:,k,i,j], where x is treated as 2 (3,4,4) tensors, NumPy does both. It broadcasts k to (12,25) and then indexes the last three dimensions against a set of three (12,25) indexing arrays, reducing all three as a unit. That is why cols has shape (2,12,25).
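The combined behavior described above can be spelled out with explicit loops. This is a minimal sketch with arbitrary random data (the seed and the use of default_rng are illustrative choices): each output position (m, n) picks one element per leading slice, using the broadcast k, i and j together.

```python
import numpy as np

rng = np.random.default_rng(0)
i = rng.integers(0, 4, (12, 25))
j = rng.integers(0, 4, (12, 25))
k = rng.integers(0, 3, (12, 1))
x = rng.integers(1, 11, (2, 3, 4, 4))

# All three index arrays are broadcast to (12, 25) and applied together.
cols = x[:, k, i, j]
print(cols.shape)

# Equivalent explicit loops: k has a single column, so k[m, 0] is
# reused for every n (that is the broadcast from (12, 1) to (12, 25)).
expected = np.empty((2, 12, 25), dtype=x.dtype)
for m in range(12):
    for n in range(25):
        expected[:, m, n] = x[:, k[m, 0], i[m, n], j[m, n]]

print(np.array_equal(cols, expected))
```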

You can override this behavior somewhat using np.ix_, but all of its arguments have to be 1-D, so you're out of luck there without flattening and reshaping. That sort of defeats the purpose here, but it also works:

x2 = x[np.ix_(np.arange(x.shape[0]), k.flat, i.flat, j.flat)].reshape((x.shape[0], ) + k.shape + i.shape + j.shape)

x2.shape
(2, 12, 1, 12, 25, 12, 25)

np.all(x1 == x2)
True
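Another possibility, offered here only as a sketch and not as part of the approaches above, is to give each index array its own set of axes with np.newaxis (None) so that broadcasting produces the outer combination instead of pairing elements up. The names kk, ii and jj are hypothetical helpers introduced for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
i = rng.integers(0, 4, (12, 25))
j = rng.integers(0, 4, (12, 25))
k = rng.integers(0, 3, (12, 1))
x = rng.integers(1, 11, (2, 3, 4, 4))

# Place each index array on its own pair of axes within a 6-D index space.
kk = k[:, :, None, None, None, None]   # (12, 1, 1, 1, 1, 1)
ii = i[None, None, :, :, None, None]   # (1, 1, 12, 25, 1, 1)
jj = j[None, None, None, None, :, :]   # (1, 1, 1, 1, 12, 25)

# The three arrays now broadcast to (12, 1, 12, 25, 12, 25).
x3 = x[:, kk, ii, jj]
print(x3.shape)

# Compare with indexing one axis at a time.
chained = x[:, k, :, :][:, :, :, i, :][:, :, :, :, :, j]
print(np.array_equal(x3, chained))
```

This avoids both the intermediate arrays of the chained version and the flatten-and-reshape step, at the cost of having to write out the axis placement by hand.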

Upvotes: 1
