Him

Reputation: 5551

How to broadcast numpy indexing along batch dimensions?

For example, np.array([[1,2],[3,4]])[np.triu_indices(2)] has shape (3,): it is a flat array of the upper triangular entries, [1, 2, 4]. However, if I have a batch of 2x2 matrices:

foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)

and I want to obtain the upper triangular entries of each matrix, the naive thing to try would be:

foo[:,np.triu_indices(2)]

However, this object actually has shape (30,2,3,2), as opposed to the (30,3) we might expect if the upper triangular entries had been extracted batch-wise.
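A quick way to see where the (30,2,3,2) shape comes from: inside an index tuple, the tuple returned by np.triu_indices(2) is converted into a single (2,3) integer array, which then indexes only axis 1 (the rows of each matrix) while axis 2 is left untouched. The sketch below reuses the foo from the question and compares the naive index with an explicit np.array(...) index.

```python
import numpy as np

# Same batch as in the question: 30 copies of [[1, 2], [3, 4]], shape (30, 2, 2)
foo = np.repeat(np.array([[[1, 2], [3, 4]]]), 30, axis=0)

idx = np.triu_indices(2)          # tuple of two arrays: rows [0, 0, 1], cols [0, 1, 1]
naive = foo[:, idx]               # the inner tuple becomes one (2, 3) integer array on axis 1
as_array = foo[:, np.array(idx)]  # explicit equivalent of the naive index

print(np.array(idx).shape)  # (2, 3)
print(naive.shape)          # (30, 2, 3, 2)
```

So the naive expression picks whole rows of each matrix (30 x 2 x 3 of them, each of length 2) instead of pairing row and column indices.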

How can we broadcast tuple indexing along the batch dimensions?

Upvotes: 5

Views: 939

Answers (1)

Divakar

Reputation: 221574

Unpack the row and column index arrays returned by np.triu_indices and use them to index into the last two dimensions:

r,c = np.triu_indices(2)
out = foo[:,r,c]
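To make the result easy to check, here is the same indexing on a small batch of distinct matrices (the np.arange data is just illustrative, not from the question). Because r and c are separate index arrays on adjacent axes, they broadcast against each other and select one element per (row, column) pair, while the leading slice keeps the batch axis.

```python
import numpy as np

# Illustrative batch of 4 distinct 2x2 matrices: [[0, 1], [2, 3]], [[4, 5], [6, 7]], ...
foo = np.arange(16).reshape(4, 2, 2)

r, c = np.triu_indices(2)  # r = [0, 0, 1], c = [0, 1, 1]
out = foo[:, r, c]         # paired (r, c) indices select from the last two axes

print(out.shape)  # (4, 3)
print(out[0])     # [0 1 3]
```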

Alternatively, here is a one-liner with Ellipsis that works for both 3D and 2D arrays:

foo[(Ellipsis,)+np.triu_indices(2)]
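Since Ellipsis expands to however many leading axes there are, the same expression should also cover more than one batch axis. The sketch below wraps it in a hypothetical helper, triu_batched, which assumes the last two axes are square and reads n from the last axis.

```python
import numpy as np

def triu_batched(a):
    """Upper triangular entries (diagonal included) of the last two axes of `a`.

    Hypothetical helper: assumes a.shape[-2] == a.shape[-1] == n and works for
    any number of leading batch axes, including none.
    """
    n = a.shape[-1]
    # (Ellipsis, rows, cols): Ellipsis absorbs all leading axes
    return a[(Ellipsis,) + np.triu_indices(n)]

m2 = np.array([[1, 2], [3, 4]])                     # plain 2D matrix
m3 = np.repeat(m2[None], 30, axis=0)                # shape (30, 2, 2), as in the question
m4 = np.arange(5 * 6 * 3 * 3).reshape(5, 6, 3, 3)   # two batch axes of 3x3 matrices

print(triu_batched(m2))        # [1 2 4]
print(triu_batched(m3).shape)  # (30, 3)
print(triu_batched(m4).shape)  # (5, 6, 6)
```

A 3x3 upper triangle including the diagonal has 6 entries, hence the trailing 6 for m4.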

The r,c indexing works the same way for a 2D array, just without the leading slice:

out = foo[r,c] # foo as 2D input array
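One way to convince yourself that the batched form does what the 2D form does per matrix is to compare the two directly. This sketch uses an illustrative random batch of 3x3 matrices (not from the original) and applies the 2D expression to each matrix in a Python loop.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.integers(0, 10, size=(8, 3, 3))  # illustrative batch of 3x3 matrices

r, c = np.triu_indices(3)
batched = batch[:, r, c]                          # 3D case: slice keeps the batch axis
per_matrix = np.stack([m[r, c] for m in batch])   # 2D case applied to each matrix

print(np.array_equal(batched, per_matrix))  # True
```

The loop is only there for the comparison; the vectorized batch[:, r, c] avoids it.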

Masking way

3D array case

We can also select the entries with a boolean mask:

foo[:,~np.tri(2,k=-1, dtype=bool)]
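Here is how the mask is built and why it gives the same result as the index arrays: np.tri(2, k=-1) is True strictly below the diagonal, so its negation is True on and above it. A 2D boolean mask consumes two axes, and NumPy returns the selected elements in row-major order, which is the same order np.triu_indices uses. The batch below is illustrative.

```python
import numpy as np

mask = ~np.tri(2, k=-1, dtype=bool)  # [[True, True], [False, True]]

foo = np.arange(16).reshape(4, 2, 2)  # illustrative batch of distinct 2x2 matrices
out_mask = foo[:, mask]               # the 2D mask applies to the last two axes

r, c = np.triu_indices(2)
print(out_mask.shape)                          # (4, 3)
print(np.array_equal(out_mask, foo[:, r, c]))  # True
```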

2D array case

foo[~np.tri(2,k=-1, dtype=bool)]
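The mask can presumably be combined with Ellipsis as well, so a single expression handles the 2D case and any number of batch axes. This is an extension of the answer rather than part of it; the sketch assumes square n x n trailing axes, with n = 3 chosen only for illustration.

```python
import numpy as np

n = 3
mask = ~np.tri(n, k=-1, dtype=bool)  # True on and above the diagonal

m2 = np.arange(n * n).reshape(n, n)                # single matrix [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
m4 = np.arange(5 * 6 * n * n).reshape(5, 6, n, n)  # illustrative array with two batch axes

print(m2[mask])             # [0 1 2 4 5 8]
print(m4[..., mask].shape)  # (5, 6, 6)
```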

Upvotes: 4
