Zohim

vmap in JAX to loop over arguments

Let's suppose I have a function that returns the sum of its inputs.

import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

Now I would like to loop over all combinations of values of r1 and r2, evaluate some_func for each pair, and add the result to a running total. This is what I mean:

a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Double loop over every (r1[i], r2[j]) pair, accumulating the results
s = 0
for i in range(len(r1)):
    for j in range(len(r2)):
        s += some_func(a, r1[i], r2[j])

print(s)
# 18

My question is: how can I do this with jax.vmap and avoid writing the for loops? So far I have something like this:

vmap(some_func, in_axes=(None, 0,0), out_axes=0)(jnp.arange(0,3), jnp.arange(0,3))

but this gives me the following error:

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0, 0) for value tree PyTreeDef((*, *)).

I suspect the problem is in in_axes, but I am not sure how to get vmap to fix a value of r1, loop over r2, and then repeat this for every r1 while saving the intermediate results.

Any help is appreciated.
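The error comes from a mismatch between in_axes and the arguments: in_axes=(None, 0, 0) has three entries, but only two positional arguments were passed (the PyTreeDef((*, *)) in the message is that two-element argument tuple). The sketch below reuses the question's some_func and inputs; it shows that passing a as well makes the call run, but a single vmap then steps through r1 and r2 in lockstep instead of over all combinations, and that one vmap only replaces one of the two loops.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Passing all three arguments fixes the in_axes/argument-count mismatch,
# but one vmap walks r1 and r2 together: some_func(a, r1[k], r2[k])
paired = vmap(some_func, in_axes=(None, 0, 0))(a, r1, r2)
print(paired)              # [0 2 4]
print(int(paired.sum()))   # 6, not 18: only the "diagonal" pairs are visited

# One vmap replaces one loop: hold r1[i] fixed (in_axes None) and map over r2
rows = [vmap(some_func, in_axes=(None, None, 0))(a, r1[i], r2) for i in range(len(r1))]
print(int(sum(row.sum() for row in rows)))  # 18
```

Removing the remaining Python loop over r1 means wrapping this in a second vmap.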

Answers (1)

jakevdp

vmap maps over a single axis at a time. Because you want to map over two arguments (r1 and r2) independently, you need two nested vmap calls:

from jax import vmap

# Inner vmap maps over r1 (argument 1); outer vmap maps over r2 (argument 2)
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
func_mapped(a, r1, r2).sum()
# 18
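Before the .sum(), the nested call returns the full grid of intermediate results the question asks about. In this sketch the lengths of r1 and r2 are changed (an assumption made only so the two axes can be told apart): with the nesting above the grid has shape (len(r2), len(r1)), and swapping the two vmap layers gives grid[i, j] == some_func(a, r1[i], r2[j]), the same layout as the double for loop.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 0
r1 = jnp.arange(0, 2)        # length 2 (illustrative choice)
r2 = jnp.arange(0, 3) * 10   # length 3 (illustrative choice)

# Answer's ordering: inner vmap over r1, outer over r2 -> shape (len(r2), len(r1))
grid_r2_r1 = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))(a, r1, r2)

# Swapped ordering: inner vmap over r2, outer over r1 -> shape (len(r1), len(r2)),
# so grid_r1_r2[i, j] == some_func(a, r1[i], r2[j])
grid_r1_r2 = vmap(vmap(some_func, (None, None, 0)), (None, 0, None))(a, r1, r2)

print(grid_r2_r1.shape)  # (3, 2)
print(grid_r1_r2.shape)  # (2, 3)
```

The two grids are transposes of each other, so for a plain sum the ordering does not matter; it only matters if you want to index the intermediate results by (i, j).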

Alternatively, for a simple function like this, you can skip vmap and use NumPy-style broadcasting to get the same result:

# r1 becomes shape (1, 3, 1) and r2 becomes (1, 1, 3); together they broadcast to (1, 3, 3)
some_func(a, r1[None, :, None], r2[None, None, :]).sum()
# 18
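A sketch of the shapes involved, reusing the question's some_func and inputs and relying only on standard broadcasting rules: the three-axis version produces a (1, 3, 3) grid, but two axes are enough, and the resulting (3, 3) grid matches the nested vmap with the outer map over r1.

```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# (1, 3, 1) and (1, 1, 3) broadcast to (1, 3, 3)
grid_3d = some_func(a, r1[None, :, None], r2[None, None, :])

# (3, 1) and (1, 3) broadcast to (3, 3)
grid_2d = some_func(a, r1[:, None], r2[None, :])

# Nested vmap: inner map over r2, outer map over r1 -> the same (3, 3) grid
grid_vmap = vmap(vmap(some_func, (None, None, 0)), (None, 0, None))(a, r1, r2)

print(grid_3d.shape)       # (1, 3, 3)
print(grid_2d.shape)       # (3, 3)
print(int(grid_2d.sum()))  # 18
```

Broadcasting works here because some_func is built only from elementwise operations; for a function that assumes scalar inputs, nested vmap is the more general option.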
