marius

Reputation: 1442

How to use Jax vmap over zipped arguments?

I have the following example code, which works with a regular map:

import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

list(map(f, zip(xs, ys)))

# returns:
[DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32)]

How can I use jax.vmap instead? The naive thing is:

jax.vmap(f)(zip(xs, ys))

but this gives:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

Upvotes: 4

Views: 2585

Answers (2)

jakevdp

Reputation: 86453

vmap is designed to map over multiple arguments by default, so no zip is needed. Furthermore, it can only map over array axes, not over elements of lists or tuples. So a more canonical way to write your example would be to convert your lists to arrays and do something like this:

def g(x, y):
  return x.sum() + y.sum()

xs_arr = jnp.asarray(xs)
ys_arr = jnp.asarray(ys)

jax.vmap(g)(xs_arr, ys_arr)
# DeviceArray([0., 0., 0., 0.], dtype=float32)
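
As a quick check that the vmapped call agrees with an ordinary Python loop, here is a minimal sketch. The non-zero inputs, the use of jnp.stack in place of jnp.asarray, and the names batched and looped are assumptions made for illustration; they are not part of the original example:

```python
import jax
import jax.numpy as jnp

def g(x, y):
    return x.sum() + y.sum()

# Illustrative non-zero inputs so the comparison is meaningful.
xs = [jnp.full((3,), float(i)) for i in range(4)]
ys = [jnp.full((2,), float(i)) for i in range(4)]

# Stack each list into a single array with a leading batch axis.
xs_arr = jnp.stack(xs)  # shape (4, 3)
ys_arr = jnp.stack(ys)  # shape (4, 2)

# vmap maps g over axis 0 of both arguments.
batched = jax.vmap(g)(xs_arr, ys_arr)

# Reference result computed one pair at a time.
looped = jnp.stack([g(x, y) for x, y in zip(xs, ys)])
```

Both results should have shape (4,) and hold the same values. Note that stacking only works when every element of a list has the same shape.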

Upvotes: 3

I'mahdi

Reputation: 24059

To use jax.vmap, you do not need to zip your variables. Build xs and ys as batched arrays and pass them together as a single tuple:

import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = jnp.zeros((4,3))
ys = jnp.zeros((4,2))
vmap(f)((xs, ys))

Output:

DeviceArray([0., 0., 0., 0.], dtype=float32)
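
The tuple form works because vmap treats the tuple as a container (a pytree) and maps over the leading axis of each array inside it. A minimal sketch, assuming non-zero inputs for illustration, that also writes the default mapping out explicitly through in_axes (the names out_default and out_explicit are hypothetical):

```python
import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = jnp.ones((4, 3))
ys = jnp.ones((4, 2))

# Default: map over axis 0 of every array in the tuple argument.
out_default = vmap(f)((xs, ys))

# The same mapping spelled out: one positional argument (the tuple),
# whose two members are each mapped along axis 0.
out_explicit = vmap(f, in_axes=((0, 0),))((xs, ys))
```

Both calls should return an array of shape (4,). Replacing one of the 0 entries with None would, under the same assumptions, leave that member unbatched and reuse it for every element of the batch.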

Upvotes: 3
