Reputation: 141
How can we extract sub-matrices of a matrix, given a batch of column-index sets (in Python/JAX)?
import jax
import jax.numpy as jnp

i = [[0, 1], [1, 2], [2, 3]]
a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])

def extract(A, idx):
    # Select the columns of A listed in idx.
    A = A[:, idx]
    return A

B = extract(a, i)
I expect to get this result, with one 2x2 matrix per index pair, stacked along the first axis:
B = [[[1, 2],
      [2, 3]],
     [[2, 3],
      [3, 4]],
     [[3, 4],
      [4, 5]]]
And NOT:
B_ = [[[1, 2],
       [2, 3],
       [3, 4]],
      [[2, 3],
       [3, 4],
       [4, 5]]]
In B_ (shape (2, 3, 2)) the selected columns are grouped by row of a, but I want the per-index-pair matrices stacked instead (shape (3, 2, 2)).
I tried
jax.vmap(extract)(a, i)
but this raises an error, since a and i don't have matching dimensions. Is there an alternative that doesn't use loops?
Upvotes: 2
Views: 88
Reputation: 15482
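For this a, you can index the transposed matrix directly:
a.T[i, :]
Note that each block comes out transposed relative to a[:, idx]; that makes no difference here only because every 2x2 block of this particular a is symmetric.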
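A quick check of the transpose shortcut, using a hypothetical non-symmetric matrix m that is not from the question (a sketch, assuming a recent JAX). In general a.T[i, :] yields each block transposed, while jnp.moveaxis(a[:, i], 1, 0) keeps the orientation of the expected B:

```python
import jax.numpy as jnp

i = jnp.array([[0, 1], [1, 2], [2, 3]])
a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])

# Hypothetical matrix whose 2x2 column blocks are NOT symmetric.
m = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8]])

via_transpose = m.T[i, :]             # shape (3, 2, 2); block k equals m[:, i[k]].T
direct = m[:, i]                      # shape (2, 3, 2); index-set axis sits in the middle
stacked = jnp.moveaxis(direct, 1, 0)  # shape (3, 2, 2); block k equals m[:, i[k]]

# For m the two differ: via_transpose[0] is [[1, 5], [2, 6]], stacked[0] is [[1, 2], [5, 6]].
# For the question's a they agree, because each block a[:, i[k]] is symmetric.
same_for_a = bool(jnp.array_equal(a.T[i, :], jnp.moveaxis(a[:, i], 1, 0)))
```

The moveaxis form works for any matrix; the transpose form coincides with B only when the blocks happen to be symmetric.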
Upvotes: 1
Reputation: 86300
You can do this with vmap if you specify in_axes correctly and convert your index list into an index array:
jax.vmap(extract, in_axes=(None, 0))(a, jnp.array(i))
# DeviceArray([[[1, 2],
#               [2, 3]],
#
#              [[2, 3],
#               [3, 4]],
#
#              [[3, 4],
#               [4, 5]]], dtype=int32)
Passing in_axes=(None, 0) tells vmap that the first argument is not mapped (the same a is used in every call) and that the second argument is mapped along its leading axis.
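To see what in_axes=(None, 0) means in practice, here is a minimal sketch (assuming a recent JAX) comparing the vmap call with a loop-based reference that calls extract once per row of the index array and stacks the results, plus an equivalent closure form in which a is captured rather than passed:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])
idx = jnp.array([[0, 1], [1, 2], [2, 3]])

def extract(A, idx):
    # Select the columns of A listed in idx.
    return A[:, idx]

# in_axes=(None, 0): A is shared by every call, idx is split along its axis 0.
batched = jax.vmap(extract, in_axes=(None, 0))(a, idx)

# Loop-based reference: one call per row of idx, stacked along a new axis 0.
looped = jnp.stack([extract(a, row) for row in idx])

# Closure form: a is captured, so only idx is passed and the default in_axes=0 applies.
closed = jax.vmap(lambda row: extract(a, row))(idx)
```

All three produce the same (3, 2, 2) array; the vmap versions avoid the Python loop.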
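You need to convert i from a list to an array because vmap maps over array leaves: when it encounters a container such as a list, tuple, dict, or any other pytree, it tries to map over each array-like value inside it. For the nested list i, those values are six Python ints with no axis 0 to map over, so the call fails; jnp.array(i) is a single array of shape (3, 2) whose leading axis can be mapped.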
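To make the pytree point concrete, a small sketch (assuming a recent JAX; the lambda is a stand-in for extract) showing what vmap sees for the nested list versus the array:

```python
import jax
import jax.numpy as jnp

i = [[0, 1], [1, 2], [2, 3]]
a = jnp.array([[1, 2, 3, 4], [2, 3, 4, 5]])

# As a pytree, the nested list has six leaves, each a rank-0 Python int.
leaves = jax.tree_util.tree_leaves(i)   # [0, 1, 1, 2, 2, 3]

# As an array, it is one leaf with a leading axis of size 3 that vmap can map over.
idx = jnp.array(i)                      # shape (3, 2)

try:
    jax.vmap(lambda A, ix: A[:, ix], in_axes=(None, 0))(a, i)
    failed = False
except (ValueError, TypeError):
    # vmap rejects mapping a rank-0 leaf along axis 0.
    failed = True
```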
Upvotes: 2