What are the differences between `jax.numpy.vectorize` and `jax.vmap`?

Here is a small snippet:
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x) * jnp.sin(x)

gf = jax.grad(f)
x = jnp.arange(0, 1, 0.1)

jax.vmap(gf)(x)
jnp.vectorize(gf)(x)
```
Both computations give the same result:

```
DeviceArray([ 1.        ,  0.80998397,  0.63975394,  0.4888039 ,  0.35637075,
              0.24149445,  0.14307144,  0.05990037, -0.00927836, -0.06574923], dtype=float32)
```
How do I decide which one to use, and is there a difference in performance?
`jax.vmap` and `jax.numpy.vectorize` have quite different semantics, and only happen to be similar in the case of a single 1D input, as in your example.
The purpose of `jax.vmap` is to map a function over one or more inputs along a single explicit axis, as specified by the `in_axes` parameter. On the other hand, `jax.numpy.vectorize` maps a function over one or more inputs along zero or more implicit axes according to NumPy broadcasting rules.
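To make "explicit axis" concrete: `in_axes` chooses which axis `vmap` maps over, and `None` excludes an argument from mapping entirely. A minimal sketch; the helper names `row_sum` and `weighted_sum` are illustrative only:

```python
import jax
import jax.numpy as jnp

def row_sum(v):
    # v is a single slice of the input taken along the mapped axis
    return v.sum()

x = jnp.ones((20, 10))

# in_axes=0 (the default): map over the 20 slices along axis 0
over_rows = jax.vmap(row_sum)(x)
# in_axes=1: map over the 10 slices along axis 1 instead
over_cols = jax.vmap(row_sum, in_axes=1)(x)
print(over_rows.shape, over_cols.shape)  # (20,) (10,)

def weighted_sum(v, w):
    return (v * w).sum()

# in_axes can be given per argument; None means "pass this argument unmapped"
w = jnp.arange(10.0)
per_row = jax.vmap(weighted_sum, in_axes=(0, None))(x, w)
print(per_row.shape)  # (20,)
```

In every case exactly one axis per argument is mapped, and you say which one; nothing is inferred from the shapes.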
To see the difference, let's pass two 2-dimensional inputs and print the shape within the function:
```python
import jax
import jax.numpy as jnp

def print_shape(x, y):
    print(f"x.shape = {x.shape}")
    print(f"y.shape = {y.shape}")
    return x + y

x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))

_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)

_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()
```
Notice that `vmap` only maps along the first axis, while `vectorize` maps along both input axes.
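If you want `vectorize` to behave like the default `vmap` here, its `signature` argument (NumPy generalized-ufunc style) declares "core" dimensions that are passed to the function unmapped. A minimal sketch reusing the 2D shapes above; `add_rows` is an illustrative name:

```python
import jax
import jax.numpy as jnp

def add_rows(x, y):
    # With the signature below, x and y each arrive as a 1D vector
    assert x.shape == y.shape == (10,)
    return x + y

x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))

# "(n),(n)->(n)" marks the last axis as a core dimension, so vectorize maps
# only over the remaining leading axis, as jax.vmap does by default.
out_vectorize = jnp.vectorize(add_rows, signature="(n),(n)->(n)")(x, y)
out_vmap = jax.vmap(add_rows)(x, y)
print(out_vectorize.shape, out_vmap.shape)  # (20, 10) (20, 10)
```

Without `signature`, `vectorize` treats every input axis as a mapped axis, which is why `print_shape` saw scalars.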
Notice also that the implicit mapping of `vectorize` means it can be used much more flexibly; for example:
```python
x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
    # vectorize always maps over all axes, such that the function is applied elementwise
    assert x.shape == y.shape == ()
    return x + y

jnp.vectorize(add)(x2, y2).shape
# (20, 10)
```
`vectorize` will iterate over all axes of the inputs according to NumPy broadcasting rules. On the other hand, `vmap` cannot handle this by default:
```python
jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
#   arg 0 has shape (10,) and axis 0 is to be mapped
#   arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
#   arg 0 has an axis to be mapped of size 10
#   arg 1 has an axis to be mapped of size 20
```
Accomplishing the same operation with `vmap` requires more thought, because there are two separate mapped axes, and some of the axes are broadcast. But you can do it this way:
```python
# Outer vmap maps y over the 20 rows; inner vmap maps x over the 10 columns.
jax.vmap(jax.vmap(add, in_axes=(0, None)), in_axes=(None, 0))(x2, y2[:, 0]).shape
# (20, 10)
```
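Another option, assuming you don't mind materializing the broadcast inputs, is to broadcast them explicitly with `jnp.broadcast_arrays` and then nest two plain `vmap`s. This may cost extra memory, since full-size (20, 10) inputs are created up front, whereas `in_axes=None` simply reuses the unbroadcast argument. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def add(x, y):
    assert x.shape == y.shape == ()
    return x + y

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

# Materialize both inputs at the full broadcast shape (20, 10) ...
xb, yb = jnp.broadcast_arrays(x2, y2)
# ... then map axis 0 (outer vmap) and axis 1 (inner vmap) of both arguments.
via_broadcast = jax.vmap(jax.vmap(add))(xb, yb)
via_vectorize = jnp.vectorize(add)(x2, y2)
print(via_broadcast.shape, bool((via_broadcast == via_vectorize).all()))  # (20, 10) True
```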
This latter nested `vmap` is essentially what is happening under the hood when you use `jax.numpy.vectorize`.
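To illustrate that idea, here is a simplified sketch of a helper that builds the nested `vmap` automatically: size-1 (broadcast) axes get `in_axes=None`, and every other axis gets one `vmap` level. `broadcast_vmap` is a hypothetical name, not JAX's actual source; it assumes positional arguments only and a function mapping scalars to a scalar (the real `jnp.vectorize` also handles `signature`, `excluded`, and more).

```python
import jax
import jax.numpy as jnp
import numpy as np

def broadcast_vmap(f):
    """Hypothetical helper: apply a scalar-to-scalar function f elementwise
    over broadcast inputs by nesting one jax.vmap per broadcast axis."""
    def wrapped(*args):
        args = [jnp.asarray(a) for a in args]
        shape = np.broadcast_shapes(*(a.shape for a in args))
        ndim = len(shape)
        # Left-pad each shape with 1s so every argument has ndim axes.
        padded = [(1,) * (ndim - a.ndim) + a.shape for a in args]
        # Drop size-1 axes; those arguments are broadcast via in_axes=None instead.
        squeezed = [a.reshape(tuple(s for s in p if s != 1)) for a, p in zip(args, padded)]
        g = f
        # Wrap from the innermost (last) axis outwards.
        for axis in reversed(range(ndim)):
            in_axes = tuple(None if p[axis] == 1 else 0 for p in padded)
            if all(ax is None for ax in in_axes):
                continue  # size 1 for every argument: nothing to map
            g = jax.vmap(g, in_axes=in_axes)
        # Reshape to restore any all-size-1 axes skipped above.
        return g(*squeezed).reshape(shape)
    return wrapped

def add(x, y):
    assert x.shape == y.shape == ()
    return x + y

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)
result = broadcast_vmap(add)(x2, y2)
print(result.shape)  # (20, 10)
print(bool((result == jnp.vectorize(add)(x2, y2)).all()))  # True
```

For these inputs the loop produces `vmap(vmap(add, in_axes=(0, None)), in_axes=(None, 0))` applied to `x2` and the squeezed `y2`, i.e. the same nesting written out by hand for `vmap`.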
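As for which to use in any given situation:

- If you want to map a function over a single, explicitly specified axis of the inputs, use `jax.vmap`.
- If you want a function's inputs to be mapped over zero or more axes according to NumPy broadcasting rules, use `jax.numpy.vectorize`.
- In cases where the two are equivalent (such as mapping over a single 1D input, as in your example), lean toward `jax.vmap`, because it more directly does what you want to do.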
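Applied to the function from the question, a minimal sketch: with a single 1D input both transforms give the same values, while with a 2D input `vectorize` keeps applying `gf` elementwise and `vmap` needs one level of nesting per axis. (A single `jax.vmap(gf)` on a 2D input would hand 1D rows to `gf`, and `jax.grad` only supports scalar-output functions, so it would fail.)

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x) * jnp.sin(x)

gf = jax.grad(f)  # gf expects a scalar input

x = jnp.arange(0, 1, 0.1)
# Single 1D input: both transforms agree, so jax.vmap is the more direct choice.
same = bool(jnp.allclose(jax.vmap(gf)(x), jnp.vectorize(gf)(x)))
print(same)  # True

# 2D input: vectorize still applies gf elementwise ...
X = x.reshape(2, 5)
elementwise = jnp.vectorize(gf)(X)
# ... whereas vmap needs one level of nesting per mapped axis.
nested = jax.vmap(jax.vmap(gf))(X)
print(elementwise.shape, nested.shape)  # (2, 5) (2, 5)
```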