I've been using this reference to understand advanced indexing. One specific example is as follows:
Example Suppose x.shape is (10,20,30) and ind is a (2,3,4)-shaped indexing intp array, then result = x[...,ind,:] has shape (10,2,3,4,30) because the (20,)-shaped subspace has been replaced with a (2,3,4)-shaped broadcasted indexing subspace. If we let i, j, k loop over the (2,3,4)-shaped subspace then result[...,i,j,k,:] = x[...,ind[i,j,k],:]. This example produces the same result as x.take(ind, axis=-2).
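The documentation example can be checked directly. This is a small sketch that uses the shapes from the quoted text with arbitrary random values. The loop variables a, b, c stand in for the docs' i, j, k so they don't clash with my own index arrays below.

```python
import numpy as np

# Shapes from the documentation example; the values themselves are arbitrary.
rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=(10, 20, 30))
ind = rng.integers(0, 20, size=(2, 3, 4))  # valid positions along the size-20 axis

result = x[..., ind, :]
print(result.shape)  # (10, 2, 3, 4, 30)

# The single (2,3,4) index array replaces the size-20 axis in place.
assert result.shape == (10, 2, 3, 4, 30)
# Element rule from the docs: result[..., a, b, c, :] == x[..., ind[a, b, c], :]
assert np.array_equal(result[..., 1, 2, 3, :], x[..., ind[1, 2, 3], :])
# Equivalent to take() along the second-to-last axis.
assert np.array_equal(result, x.take(ind, axis=-2))
```

Note that this example has only one index array, so there is nothing for it to be broadcast against.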
I've been trying to understand this for a while, and to help me I've got a little script that produces some arrays. I have:
Indexing arrays
i => 12 x 25
j => 12 x 25
k => 12 x 1
Input array
x => 2 x 3 x 4 x 4
Output Array
cols => 2 x 12 x 25
The code I use to make cols is as follows:
cols = x[:, k, i, j]
From my understanding of the example, cols should actually have shape (2 x 12 x 1 x 12 x 25 x 12 x 25). I've come to this as follows:
Its original dimensions are 2 x 3 x 4 x 4
The 2 is unchanged but all other dimensions are altered
The 3 is replaced with k, a 12 x 1 array
The first 4 is replaced by i, a 12 x 25 array
The second 4 is replaced by j, also a 12 x 25 array
Clearly I'm misunderstanding something here, where am I going wrong?
This does what you want:
import numpy as np
i=np.random.randint(0,4,(12,25))
j=np.random.randint(0,4,(12,25))
k=np.random.randint(0,3,(12,1))
x=np.random.randint(1,11,(2,3,4,4))
x1 = x[:,k,:,:][:,:,:,i,:][:,:,:,:,:,j]
x1.shape
(2, 12, 1, 12, 25, 12, 25)
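To see why the chained version gives the "outer product" shape, it can help to spot-check one element. This is a minimal sketch with the same shapes and arbitrary random values. The particular indices a through g are just one valid position I picked for illustration.

```python
import numpy as np

# Same shapes as above; the random values are arbitrary.
i = np.random.randint(0, 4, (12, 25))
j = np.random.randint(0, 4, (12, 25))
k = np.random.randint(0, 3, (12, 1))
x = np.random.randint(1, 11, (2, 3, 4, 4))

# Each step indexes only one axis, so each index array keeps all of its own axes.
x1 = x[:, k, :, :][:, :, :, i, :][:, :, :, :, :, j]

# One arbitrary (but in-bounds) position in the 7-d result.
a, b, c, d, e, f, g = 1, 5, 0, 7, 3, 11, 24
# k, i and j each contribute their own block of axes (an outer product).
assert x1[a, b, c, d, e, f, g] == x[a, k[b, c], i[d, e], j[f, g]]
```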
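Why doesn't the original method work that way? When several advanced indices appear in the same indexing operation, NumPy does not combine them as an outer product. It broadcasts them against each other and reads them together, position by position, so they index several dimensions simultaneously. For instance, your original shape: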
x.shape
(2,3,4,4)
This shape can be interpreted in many ways. What you want is for each axis to be independent, but it is just as valid to interpret it as 6 (4,4) matrices or 2 (3,4,4) tensors. So when indexing by [...,i,j], you can interpret i as indexing the third axis and j the fourth, or i,j together as indexing the last two axes. NumPy always takes the second reading:
x[...,i,j].shape
(2,3,12,25)
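A quick sketch (arbitrary random values, same shapes) confirming that i and j are paired up rather than combined every-with-every: position (p, q) of the result reads x[..., i[p, q], j[p, q]].

```python
import numpy as np

i = np.random.randint(0, 4, (12, 25))
j = np.random.randint(0, 4, (12, 25))
x = np.random.randint(1, 11, (2, 3, 4, 4))

r = x[..., i, j]
assert r.shape == (2, 3, 12, 25)

# i and j are used together: each (p, q) picks one (row, column) pair
# from the last two axes, instead of pairing every i with every j.
for p in range(12):
    for q in range(25):
        assert np.array_equal(r[..., p, q], x[..., i[p, q], j[p, q]])
```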
You can also interpret x as 8 (3,4) matrices, which is what happens when you do:
x[:,k,i,:].shape
(2,12,25,4)
Notice that it has also broadcast your (12,1) k array to (12,25) in order to match i for indexing. You can confirm that broadcasting is happening by using .squeeze() on k:
x[:,k.squeeze(),i,:]
Traceback (most recent call last):
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (12,) (12,25)
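The same broadcasting rule can be checked without indexing at all. This sketch assumes NumPy 1.20 or later, where np.broadcast_shapes is available. On a mismatch it raises ValueError, whereas the indexing operation above reports the same problem as an IndexError. It also shows that putting back the trailing length-1 axis makes the squeezed k usable again.

```python
import numpy as np

# np.broadcast_shapes (NumPy >= 1.20) applies the ordinary broadcasting rules,
# the same ones used to combine advanced index arrays.
print(np.broadcast_shapes((12, 1), (12, 25)))  # (12, 25)

try:
    np.broadcast_shapes((12,), (12, 25))  # trailing sizes 12 and 25 clash
except ValueError as err:
    print("cannot broadcast:", err)

# Restoring the length-1 axis on the squeezed k gives back the (12, 1) shape.
k = np.random.randint(0, 3, (12, 1))
i = np.random.randint(0, 4, (12, 25))
x = np.random.randint(1, 11, (2, 3, 4, 4))
assert np.array_equal(x[:, k.squeeze()[:, None], i, :], x[:, k, i, :])
```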
If you interpret x as 2 (3,4,4) tensors, NumPy does both. This is your original x[:, k, i, j]: it broadcasts k to (12,25) and then indexes the last three dimensions with a set of three (12,25) indexing arrays, reducing all three as a unit, which is why cols ends up as (2,12,25).
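Here is a sketch of the element rule for the original x[:, k, i, j], with the same shapes and arbitrary random values. I use np.broadcast_to only to make the broadcast k explicit for the comparison. Indexing itself doesn't need it.

```python
import numpy as np

i = np.random.randint(0, 4, (12, 25))
j = np.random.randint(0, 4, (12, 25))
k = np.random.randint(0, 3, (12, 1))
x = np.random.randint(1, 11, (2, 3, 4, 4))

cols = x[:, k, i, j]
assert cols.shape == (2, 12, 25)

# k is broadcast from (12, 1) to (12, 25), so all three index arrays share
# a single (12, 25) index space and are read together at each (p, q).
kb = np.broadcast_to(k, i.shape)
for p in range(12):
    for q in range(25):
        assert np.array_equal(cols[:, p, q], x[:, kb[p, q], i[p, q], j[p, q]])
```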
You can override this behavior somewhat using np.ix_, but all the arguments of np.ix_ have to be 1-D, so you're out of luck there without flattening and reshaping. That sort of defeats the purpose here, but it also works:
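# np.ix_ only accepts 1-D arrays, so flatten k, i and j into index lists...
x2 = x[np.ix_(np.arange(x.shape[0]), k.flat, i.flat, j.flat)]
# ...then split each flattened axis back into its index array's original shape
x2 = x2.reshape((x.shape[0], ) + k.shape + i.shape + j.shape)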
x2.shape
(2, 12, 1, 12, 25, 12, 25)
np.all(x1 == x2)
True
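As an alternative that avoids flattening (my own variation, not part of the method above), you can mimic what np.ix_ does by hand. np.ix_ gives each 1-D argument its own axis, for example np.ix_(np.arange(2), np.arange(3)) returns arrays of shape (2, 1) and (1, 3). Appending length-1 axes to the n-d index arrays achieves the same separation, so that when they broadcast, each one occupies its own block of axes. The sketch below uses the same shapes with arbitrary random values.

```python
import numpy as np

i = np.random.randint(0, 4, (12, 25))
j = np.random.randint(0, 4, (12, 25))
k = np.random.randint(0, 3, (12, 1))
x = np.random.randint(1, 11, (2, 3, 4, 4))

# Pad with trailing length-1 axes so the arrays broadcast to
# (12, 1, 12, 25, 12, 25) with no overlap between k, i and j.
k_o = k[:, :, None, None, None, None]  # (12, 1, 1, 1, 1, 1)
i_o = i[:, :, None, None]              # (12, 25, 1, 1)
j_o = j                                # (12, 25)

x3 = x[:, k_o, i_o, j_o]
assert x3.shape == (2, 12, 1, 12, 25, 12, 25)

# Same result as the chained indexing shown earlier.
x1 = x[:, k, :, :][:, :, :, i, :][:, :, :, :, :, j]
assert np.array_equal(x1, x3)
```

The padding depends on the dimensions of the index arrays that come after each one, so it has to be worked out for each case. That is the trade-off against the flatten-and-reshape approach.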