Let's suppose I have a function that returns the sum of its inputs:
```python
import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2
```
Now I would like to loop over different values of `r1` and `r2`, save the result, and add it to a counter. This is what I mean:
```python
a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)
s = 0
for i in range(len(r1)):
    for j in range(len(r2)):
        s += some_func(a, r1[i], r2[j])
print(s)  # 18 (older JAX versions display the value as DeviceArray(18, dtype=int32))
```
My question is: how do I do this with `jax.vmap` to avoid writing the `for` loops? I have something like this so far:
```python
vmap(some_func, in_axes=(None, 0, 0), out_axes=0)(jnp.arange(0, 3), jnp.arange(0, 3))
```
but this gives me the following error:

```
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0, 0) for value tree PyTreeDef((*, *)).
```
I have a feeling that the error is in `in_axes`, but I am not sure how to get `vmap` to pick a value for `r1`, loop over `r2`, and then repeat this for every value of `r1` while saving the intermediate results.

Any help is appreciated.
`vmap` maps over a single axis at a time. Because you want to map over two different axes, you'll need two `vmap` calls:
```python
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
func_mapped(a, r1, r2).sum()
# 18
```
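To see why one `vmap` is not enough, here is a minimal sketch (assuming a recent JAX install; the names `zipped` and `grid` are illustrative, not from the question) that compares a single `vmap` over both arrays with the nested call. With `in_axes=(None, 0, 0)` and `a` passed explicitly, `vmap` walks `r1` and `r2` in lockstep, so it only evaluates the diagonal pairs. The nested version evaluates every combination and keeps them in an array of shape (3, 3) before summing.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# A single vmap maps r1 and r2 together, element by element (zip-like),
# so it only computes some_func(a, r1[k], r2[k]) for k = 0, 1, 2.
zipped = vmap(some_func, in_axes=(None, 0, 0))(a, r1, r2)
print(zipped)        # [0 2 4]
print(zipped.sum())  # 6, not 18

# Nested vmap: the inner call maps over r1, the outer call over r2,
# so grid[j, i] == some_func(a, r1[i], r2[j]) for every pair (i, j).
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
grid = func_mapped(a, r1, r2)
print(grid.shape)    # (3, 3)
print(grid.sum())    # 18
```

The intermediate results the question asks to save are simply the entries of `grid`; `.sum()` is only applied at the end.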
Alternatively, for a simple function like this, you can avoid `vmap` and use NumPy-style broadcasting to get the same result:
```python
some_func(a, r1[None, :, None], r2[None, None, :]).sum()
# 18
```
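The broadcasting version works because `r1[None, :, None]` has shape (1, 3, 1) and `r2[None, None, :]` has shape (1, 1, 3), so `a + r1 + r2` broadcasts to shape (1, 3, 3) and holds every pairwise sum. A minimal sketch, assuming the same setup as the question (`out` and `grid` are illustrative names), that checks this against the nested `vmap` result:

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Shapes (1, 3, 1) and (1, 1, 3) broadcast to (1, 3, 3):
# out[0, i, j] == some_func(a, r1[i], r2[j]).
out = some_func(a, r1[None, :, None], r2[None, None, :])
print(out.shape)  # (1, 3, 3)
print(out.sum())  # 18

# The nested vmap puts r2 on axis 0, so its grid is the transpose of out[0].
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
grid = func_mapped(a, r1, r2)
print(jnp.array_equal(out[0], grid.T))  # True
```

Broadcasting needs no extra transformation but relies on every operation inside the function broadcasting correctly; nested `vmap` works for any function written for scalar `r1` and `r2`.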