Reputation: 1442
I have the following example code, which works with Python's built-in map:
import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]
list(map(f, zip(xs, ys)))
# returns:
[DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32)]
How can I use jax.vmap instead? The naive approach is:
jax.vmap(f)(zip(xs, ys))
but this gives:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
Reputation: 86453
vmap is designed to map over all of its positional arguments by default, so no zip is needed. Furthermore, it can only map over array axes, not over the elements of lists or tuples (those are treated as containers, and each array inside them is mapped along its own axis). So a more canonical way to write your example would be to convert your lists to arrays and do something like this:
def g(x, y):
    return x.sum() + y.sum()

xs_arr = jnp.asarray(xs)  # stacks the list into shape (4, 3)
ys_arr = jnp.asarray(ys)  # stacks the list into shape (4, 2)
jax.vmap(g)(xs_arr, ys_arr)
# DeviceArray([0., 0., 0., 0.], dtype=float32)
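To make the two claims above concrete, here is a small sketch with non-zero data (the arrays and the helper h are illustrative, not from the question). It shows that the default behaviour is equivalent to in_axes=(0, 0), i.e. axis 0 of every positional argument is mapped, and that passing a list does not map over the list's elements: vmap slices axis 0 of each array inside the list instead.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x.sum() + y.sum()

# Illustrative batch of 4 rows (not the question's zeros).
xs_arr = jnp.arange(12.0).reshape(4, 3)
ys_arr = jnp.ones((4, 2))

# By default every positional argument is mapped over its axis 0 ...
default = jax.vmap(g)(xs_arr, ys_arr)
# ... which is the same as spelling out in_axes explicitly.
explicit = jax.vmap(g, in_axes=(0, 0))(xs_arr, ys_arr)

# Hypothetical helper: the list is a container (pytree), not a batch.
# vmap slices axis 0 of each array in the list, so h receives one
# scalar from each array per call, never a whole list element.
def h(pair):
    a, b = pair
    return a + b

per_leaf = jax.vmap(h)([jnp.arange(3.0), 10 * jnp.arange(3.0)])
```

Here default is [5., 14., 23., 32.] and per_leaf is [0., 11., 22.]: the batch size comes from the arrays' leading axis (3), not from the list length (2).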
Reputation: 24059
To use jax.vmap, you do not need to zip your variables. Pass the stacked arrays together as one tuple, and vmap maps over axis 0 of each array in it:
import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = jnp.zeros((4, 3))
ys = jnp.zeros((4, 2))
vmap(f)((xs, ys))
Output:
DeviceArray([0., 0., 0., 0.], dtype=float32)
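If your data starts out as the question's list of (x, y) pairs, one way to get the tuple-of-stacked-arrays layout used above is jax.tree_util.tree_map with jnp.stack. This is a swapped-in technique not used in either answer; the sketch below assumes every x has the same shape and every y has the same shape, and uses illustrative non-zero data so the comparison against the original map version is meaningful (all zeros would hide mistakes).

```python
import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

# Illustrative data: 4 pairs with x of shape (3,) and y of shape (2,).
xs = [jnp.arange(3.0) + i for i in range(4)]
ys = [jnp.arange(2.0) * i for i in range(4)]

# Reference result computed with Python's map over zipped pairs.
expected = jnp.stack(list(map(f, zip(xs, ys))))

# Convert a list of (x, y) tuples into a single (x_batch, y_batch) tuple:
# tree_map walks the matching leaves of all pairs and stacks them.
pairs = list(zip(xs, ys))
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pairs)

# batched[0] has shape (4, 3) and batched[1] has shape (4, 2).
result = jax.vmap(f)(batched)
```

Both expected and result come out as [3., 7., 11., 15.], since the sum of x_i is 3i + 3 and the sum of y_i is i.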