Jean-Eric

JAX vectorization: vmap and/or numpy.vectorize?

What are the differences between jax.numpy.vectorize and jax.vmap? Here is a small snippet:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x) * jnp.sin(x)

gf = jax.grad(f)
x = jnp.arange(0, 1, 0.1)

jax.vmap(gf)(x)
jnp.vectorize(gf)(x)

Both computations give the same results:

DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923], dtype=float32)

How do I decide which one to use, and is there a difference in terms of performance?
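As a sanity check on the snippet above (this check is an addition, not part of the original question), both results can be compared with the analytic derivative from the product rule, f'(x) = e^(-x) (cos x - sin x). The tolerance `atol=1e-5` is an assumption chosen to suit float32.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x) * jnp.sin(x)

def df_exact(x):
    # Product rule: d/dx [exp(-x) * sin(x)] = exp(-x) * (cos(x) - sin(x))
    return jnp.exp(-x) * (jnp.cos(x) - jnp.sin(x))

gf = jax.grad(f)
x = jnp.arange(0, 1, 0.1)

via_vmap = jax.vmap(gf)(x)
via_vectorize = jnp.vectorize(gf)(x)

# For a scalar function mapped over a single 1D input, the two transforms agree
print(jnp.allclose(via_vmap, via_vectorize))            # True
print(jnp.allclose(via_vmap, df_exact(x), atol=1e-5))   # True
```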

Answers (1)

jakevdp

jax.vmap and jax.numpy.vectorize have quite different semantics, and only happen to be similar in the case of a single 1D input as in your example.

The purpose of jax.vmap is to map a function over one or more inputs along a single explicit axis, as specified by the in_axes parameter. On the other hand, jax.numpy.vectorize maps a function over one or more inputs along zero or more implicit axes according to numpy broadcasting rules.
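To make the explicit-axis semantics of in_axes concrete, here is a small sketch (the arrays `a` and `v` are illustrative, not from the original answer): `in_axes=1` maps over the columns of a matrix, and `None` marks an argument that is passed whole to every call rather than mapped.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(3, 4)
v = jnp.array([1.0, 2.0, 3.0])

# Map over axis 1 (the columns): each call sees one column of shape (3,)
col_sums = jax.vmap(jnp.sum, in_axes=1)(a)
print(col_sums.shape)  # (4,)

# Map over the columns of `a` while passing `v` unmapped to every call
col_dots = jax.vmap(jnp.dot, in_axes=(1, None))(a, v)
print(col_dots.shape)  # (4,)
```

The mapped axis is always stated explicitly; nothing is inferred from the shapes of the inputs.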

To see the difference, let's pass two 2-dimensional inputs and print the shape within the function:

import jax
import jax.numpy as jnp

def print_shape(x, y):
  print(f"x.shape = {x.shape}")
  print(f"y.shape = {y.shape}")
  return x + y

x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))

_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)

_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()

Notice that by default vmap maps only along the first axis, while vectorize maps along both axes of the inputs.

Notice also that the implicit mapping of vectorize makes it considerably more flexible; for example, it accepts inputs of different shapes as long as they broadcast together:

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
  # without a `signature`, vectorize maps over all axes, so the function is applied elementwise
  assert x.shape == y.shape == ()
  return x + y

jnp.vectorize(add)(x2, y2).shape
# (20, 10)
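Because add is applied to scalar pairs after broadcasting, the result should match ordinary NumPy-style broadcasting of x2 + y2. A quick check (the comparison is an addition for illustration):

```python
import jax.numpy as jnp

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
    # With no `signature`, vectorize calls this function on scalars only
    assert x.shape == y.shape == ()
    return x + y

vectorized = jnp.vectorize(add)(x2, y2)
broadcast = x2 + y2  # plain broadcasting: (10,) with (20, 1) -> (20, 10)
print(jnp.array_equal(vectorized, broadcast))  # True
```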

vectorize maps over all axes of the inputs according to NumPy broadcasting rules. vmap, on the other hand, cannot handle this by default:

jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
# arg 0 has shape (10,) and axis 0 is to be mapped
# arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
# arg 0 has an axis to be mapped of size 10
# arg 1 has an axis to be mapped of size 20

Accomplishing the same operation with vmap requires more thought, because there are two separate mapped axes and each input must be broadcast along the other's axis. But you can do it this way:

jax.vmap(jax.vmap(add, in_axes=(0, None)), in_axes=(None, 0))(x2, y2[:, 0]).shape
# (20, 10)
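The order of the nesting determines the output layout: vmap places each mapped axis first in its output (the default out_axes=0), so the vmap over y (size 20) has to be the outer one to get shape (20, 10). The sketch below, an illustrative addition, compares both nesting orders against vectorize:

```python
import jax
import jax.numpy as jnp

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
    assert x.shape == y.shape == ()
    return x + y

# Outer vmap maps y (size 20), inner vmap maps x (size 10): result (20, 10)
nested = jax.vmap(jax.vmap(add, in_axes=(0, None)), in_axes=(None, 0))(x2, y2[:, 0])

# Swapping the nesting puts the x axis first instead: result (10, 20)
swapped = jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(x2, y2[:, 0])

vectorized = jnp.vectorize(add)(x2, y2)
print(nested.shape, swapped.shape)              # (20, 10) (10, 20)
print(jnp.array_equal(nested, vectorized))      # True
print(jnp.array_equal(swapped.T, vectorized))   # True
```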

This latter nested vmap is essentially what is happening under the hood when you use jax.numpy.vectorize.
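To illustrate that idea, here is a deliberately simplified, hypothetical re-implementation. `naive_vectorize` is not a JAX API, and it ignores `signature`, `excluded`, and the other options the real jnp.vectorize supports. It broadcasts all inputs to a common shape and then wraps the function in one vmap per axis. It also materializes the broadcast copies, which a real implementation need not do.

```python
import jax
import jax.numpy as jnp

def naive_vectorize(fun):
    """Hypothetical, simplified stand-in for jnp.vectorize (no `signature` support)."""
    def wrapped(*args):
        # Apply NumPy broadcasting rules so every input has the same shape
        args = jnp.broadcast_arrays(*args)
        mapped = fun
        # One vmap per axis; each peels off a leading axis, so `fun` sees scalars
        for _ in range(args[0].ndim):
            mapped = jax.vmap(mapped)
        return mapped(*args)
    return wrapped

def add(x, y):
    assert x.shape == y.shape == ()
    return x + y

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)
out = naive_vectorize(add)(x2, y2)
print(out.shape)  # (20, 10)
```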

As for which to use in any given situation:

  • if you want to map a function across a single, explicitly specified axis of the inputs, use jax.vmap
  • if you want a function's inputs to be mapped across zero or more axes according to NumPy's broadcasting rules as applied to the inputs, use jax.numpy.vectorize.
  • in situations where the two transforms are equivalent (for example, when mapping a function over a single 1D input, as in the question), lean toward vmap, because it expresses what you want more directly.
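As an illustration of the first point (the example itself is an addition), consider a batched dot product. vmap with in_axes=(0, None) maps over the rows of a matrix while passing the vector whole, whereas vectorize without a `signature` maps over every element, so the dot product only ever sees scalars and the result is an elementwise product. The helper `dot` is a hypothetical wrapper around jnp.dot used only for this sketch.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Hypothetical helper: a thin wrapper around jnp.dot
    return jnp.dot(a, b)

m = jnp.arange(6.0).reshape(2, 3)
v = jnp.array([1.0, 0.0, -1.0])

# Explicit axis: each call gets one row of shape (3,) and the full vector v
row_dots = jax.vmap(dot, in_axes=(0, None))(m, v)
print(row_dots)  # [-2. -2.]

# Broadcasting semantics: dot is applied to scalar pairs (an elementwise product)
elementwise = jnp.vectorize(dot)(m, v)
print(elementwise.shape)  # (2, 3)
```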
