I would like to extract columns from a stack of matrices using a batch of column indices.
Say we have an array a of shape (3, 2, 4), i.e. three matrices of shape (2, 4), and an index array idx of shape (3, 2), i.e. three pairs of column indices.
import jax
import jax.numpy as jnp

def get_cols(x, idx):
    x = x[:, idx]
    return x

idx = jnp.array([[0, 1], [2, 3], [1, 2]])
a = jnp.array([[[1, 2, 3, 4],
                [3, 2, 2, 4]],
               [[100, 20, 3, 50],
                [5, 5, 2, 4]],
               [[1, 2, 3, 4],
                [3, 2, 2, 4]]])
e = jax.vmap(get_cols, in_axes=(None, 0))(a, idx)
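A quick shape check (a sketch that reuses the question's a, idx and get_cols) shows what this single vmap actually computes. With in_axes=(None, 0), every call receives the whole (3, 2, 4) stack, so x[:, idx] gathers along axis 1 (the two rows of each matrix) instead of the columns. Indices 2 and 3 are out of range for that axis, but inside vmap they are traced values, so JAX cannot raise an error for them and the slices for idx[1] and idx[2] are not meaningful.

```python
import jax
import jax.numpy as jnp

def get_cols(x, idx):
    return x[:, idx]

idx = jnp.array([[0, 1], [2, 3], [1, 2]])
a = jnp.array([[[1, 2, 3, 4],
                [3, 2, 2, 4]],
               [[100, 20, 3, 50],
                [5, 5, 2, 4]],
               [[1, 2, 3, 4],
                [3, 2, 2, 4]]])

e = jax.vmap(get_cols, in_axes=(None, 0))(a, idx)

# Each call sees x with shape (3, 2, 4) and idx with shape (2,), so
# x[:, idx] indexes axis 1 (rows) and returns shape (3, 2, 4);
# vmap then stacks 3 of those along a new leading axis.
print(e.shape)  # (3, 3, 2, 4)

# For idx[0] = [0, 1] this selects both rows of every matrix, i.e. a itself.
print(bool((e[0] == a).all()))  # True
```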
For each pair of indices in idx, I want to extract those columns from every matrix in a. I expect the following result, of shape (3, 3, 2, 2):
e = [[[[1, 2],
       [3, 2]],
      [[100, 20],
       [5, 5]],
      [[1, 2],
       [3, 2]]],
     [[[3, 4],
       [2, 4]],
      [[3, 50],
       [2, 4]],
      [[3, 4],
       [2, 4]]],
     [[[2, 3],
       [2, 2]],
      [[20, 3],
       [5, 2]],
      [[2, 3],
       [2, 2]]]]
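To pin down the target, here is a plain loop-based reference (a sketch, deliberately not vectorised) that builds the expected array directly: for each index pair idx[b], take those columns from every matrix a[m], so that expected[b, m] == a[m][:, idx[b]].

```python
import jax.numpy as jnp

idx = jnp.array([[0, 1], [2, 3], [1, 2]])
a = jnp.array([[[1, 2, 3, 4],
                [3, 2, 2, 4]],
               [[100, 20, 3, 50],
                [5, 5, 2, 4]],
               [[1, 2, 3, 4],
                [3, 2, 2, 4]]])

# Outer loop over the 3 index pairs, inner loop over the 3 matrices;
# m[:, i] picks the two requested columns of one (2, 4) matrix.
expected = jnp.stack([
    jnp.stack([m[:, i] for m in a])
    for i in idx
])

print(expected.shape)             # (3, 3, 2, 2)
print(expected[1, 1].tolist())    # [[3, 50], [2, 4]]
```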
What am I missing?
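With in_axes=(None, 0), get_cols receives the entire (3, 2, 4) stack, so x[:, idx] indexes axis 1 (the rows of each matrix) rather than the columns. What you want is a double vmap: the inner vmap maps get_cols over the three matrices with idx held fixed (in_axes=(0, None)), so each call sees a single (2, 4) matrix and x[:, idx] selects columns; the outer vmap then maps over the rows of idx with a held fixed (in_axes=(None, 0)). For example: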
e = jax.vmap(jax.vmap(get_cols, in_axes=(0, None)), in_axes=(None, 0))(a, idx)
print(e)
[[[[  1   2]
   [  3   2]]

  [[100  20]
   [  5   5]]

  [[  1   2]
   [  3   2]]]


 [[[  3   4]
   [  2   4]]

  [[  3  50]
   [  2   4]]

  [[  3   4]
   [  2   4]]]


 [[[  2   3]
   [  2   2]]

  [[ 20   3]
   [  5   2]]

  [[  2   3]
   [  2   2]]]]
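For comparison, the same result can also be obtained without vmap, using NumPy-style advanced indexing on the column axis. This is an alternative technique, not part of the answer above; the sketch below (the name e_alt is illustrative) checks that it agrees with the double vmap. a[:, :, idx] has shape (3, 2, 3, 2), ordered as (matrix, row, index pair, column), so jnp.moveaxis brings the index-pair axis to the front to match e.

```python
import jax
import jax.numpy as jnp

def get_cols(x, idx):
    # x: one (2, 4) matrix, idx: one pair of column indices
    return x[:, idx]

idx = jnp.array([[0, 1], [2, 3], [1, 2]])
a = jnp.array([[[1, 2, 3, 4],
                [3, 2, 2, 4]],
               [[100, 20, 3, 50],
                [5, 5, 2, 4]],
               [[1, 2, 3, 4],
                [3, 2, 2, 4]]])

# Double vmap: inner maps over the matrices in a, outer over the rows of idx.
e = jax.vmap(jax.vmap(get_cols, in_axes=(0, None)), in_axes=(None, 0))(a, idx)

# Index axis 2 (columns) with the whole (3, 2) idx array:
# a[:, :, idx][m, r, b, k] == a[m, r, idx[b, k]]; move axis b to the front.
e_alt = jnp.moveaxis(a[:, :, idx], 2, 0)

print(e.shape, e_alt.shape)        # (3, 3, 2, 2) (3, 3, 2, 2)
print(bool((e == e_alt).all()))    # True
```

The vmap version keeps get_cols written for a single matrix, which tends to read more clearly as the per-item logic grows; the indexing version does the same gather in one expression.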