relaxon

Reputation: 141

Get batched indices from stacked matrices - Python Jax

I would like to extract columns from a stack of matrices, using a batch of column indices.

Say we have an array a of shape (3, 2, 4), i.e. three matrices of shape (2, 4), and an index array idx of shape (3, 2).

import jax
import jax.numpy as jnp

def get_cols(x, idx):
    # select the columns listed in idx
    return x[:, idx]

idx = jnp.array([[0, 1], [2, 3], [1, 2]])

a = jnp.array([[[1, 2, 3, 4],
                [3, 2, 2, 4]],

               [[100, 20, 3, 50],
                [5, 5, 2, 4]],

               [[1, 2, 3, 4],
                [3, 2, 2, 4]]])

# my attempt
e = jax.vmap(get_cols, in_axes=(None, 0))(a, idx)

For each row of idx, I want to extract those columns from every matrix in a. I expect the following result, of shape (3, 3, 2, 2):

e = [[[[1, 2],
       [3, 2]],

      [[100, 20],
       [5, 5]],

      [[1, 2],
       [3, 2]]],

     [[[3, 4],
       [2, 4]],

      [[3, 50],
       [2, 4]],

      [[3, 4],
       [2, 4]]],

     [[[2, 3],
       [2, 2]],

      [[20, 3],
       [5, 2]],

      [[2, 3],
       [2, 2]]]]

What am I missing?

Upvotes: 1

Views: 117

Answers (1)

jakevdp

Reputation: 86300

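It looks like you're interested in a double vmap over the inputs: the inner vmap maps get_cols over the matrices in a, and the outer vmap maps over the rows of idx. For example, something like this: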

e = jax.vmap(jax.vmap(get_cols, in_axes=(0, None)), in_axes=(None, 0))(a, idx)
print(e)
[[[[  1   2]
   [  3   2]]

  [[100  20]
   [  5   5]]

  [[  1   2]
   [  3   2]]]


 [[[  3   4]
   [  2   4]]

  [[  3  50]
   [  2   4]]

  [[  3   4]
   [  2   4]]]


 [[[  2   3]
   [  2   2]]

  [[ 20   3]
   [  5   2]]

  [[  2   3]
   [  2   2]]]]
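
To see why the single vmap in the question falls short, it can help to compare output shapes. The sketch below is a self-contained check, not part of the original answer: the names `single`, `double`, and `direct` are introduced here for illustration only. It also shows one possible vmap-free alternative using plain integer-array indexing plus `jnp.moveaxis`, which should give the same values as the double vmap.

```python
import jax
import jax.numpy as jnp

def get_cols(x, idx):
    # Intended use: x is one (2, 4) matrix, idx is one (2,) vector of column indices.
    return x[:, idx]

idx = jnp.array([[0, 1], [2, 3], [1, 2]])
a = jnp.array([[[1, 2, 3, 4], [3, 2, 2, 4]],
               [[100, 20, 3, 50], [5, 5, 2, 4]],
               [[1, 2, 3, 4], [3, 2, 2, 4]]])

# Single vmap from the question: a is not mapped, so x is still 3-D inside
# get_cols and x[:, idx] indexes axis 1 (the rows), not the columns.
single = jax.vmap(get_cols, in_axes=(None, 0))(a, idx)
print(single.shape)  # (3, 3, 2, 4)

# Double vmap from the answer: inner vmap over matrices, outer vmap over idx rows.
double = jax.vmap(jax.vmap(get_cols, in_axes=(0, None)), in_axes=(None, 0))(a, idx)
print(double.shape)  # (3, 3, 2, 2)

# Alternative sketch without vmap: index the last axis with the whole idx array,
# giving shape (3, 2, 3, 2), then move the idx-row axis to the front.
direct = jnp.moveaxis(a[..., idx], 2, 0)
print(bool(jnp.array_equal(double, direct)))  # True
```

Whether the double vmap or the direct indexing reads better is largely a matter of taste; the vmap form keeps get_cols written for a single matrix, while the indexing form avoids the nested transformation.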

Upvotes: 1
