antelk

Reputation: 29

Multiple `vmap` in JAX?

This may be a very simple thing, but I was wondering how to perform the mapping in the following example.

Suppose we have a function whose derivative we want to evaluate with respect to xt, yt and zt, but which also takes additional parameters xs, ys and zs.

import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

Now, let us define the input data:

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

To evaluate the gradient for each combination of data points in xt, yt and zt, I have to do the following:

fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

a = []
for _xt in xt:
    for _yt in yt:
        for _zt in zt:
            a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))

This results in a list of tuples. Once the list is converted to a jnp.array, it has the following shape:

a = jnp.array(a)
print(f'shape = {a.shape}')
shape = (64, 3, 3)

My question is: Is there a way to avoid these nested for loops and evaluate all gradients in a single sweep?

Upvotes: 3

Views: 2952

Answers (2)

enes dilber

Reputation: 141

You could vectorize the gradient and pass grid-like (broadcastable) inputs for the arguments being differentiated:

def f_vect(xt, yt, zt, xs, ys, zs):
    fp = jnp.vectorize(grad(fn, argnums=(0, 1, 2)))
    fn_prime = vmap(fp, in_axes=(None, None, None, 0, 0, 0))
    
    a = jnp.array(fn_prime(xt[:, None, None], 
                           yt[None, :, None], 
                           zt[None, None, :],
                           xs, ys, zs)) # has shape 3,3,4,4,4
    
    a = a.reshape(a.shape[0], a.shape[1], -1) # has shape 3,3,64
    a = jnp.rollaxis(a, -1) # has shape 64,3,3
    
    return a
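
To see how the grid-like inputs work in isolation, here is a minimal sketch on a smaller, hypothetical function `g` (it is not part of the question). It assumes `jnp.vectorize` is wrapped around a scalar gradient, so that inputs shaped `(3, 1)` and `(1, 2)` broadcast to a `(3, 2)` grid of gradient values:

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical scalar function, used only for illustration.
def g(x, y):
    return x ** 2 * y

# grad works on scalars; jnp.vectorize lets it broadcast over array inputs.
dg_dx = jnp.vectorize(grad(g, argnums=0))

x = jnp.array([1., 2., 3.])
y = jnp.array([10., 20.])

# x[:, None] has shape (3, 1) and y[None, :] has shape (1, 2),
# so the result covers every (x[i], y[j]) combination.
grid = dg_dx(x[:, None], y[None, :])

# Analytically, d/dx (x**2 * y) = 2 * x * y.
expected = 2 * x[:, None] * y[None, :]
```

The same indexing trick with three arrays (`xt[:, None, None]`, `yt[None, :, None]`, `zt[None, None, :]`) yields the 4 x 4 x 4 grid used in the answer above.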

Upvotes: 0

jakevdp

Reputation: 86310

A good rule of thumb for cases like this is that each nested for loop translates to a nested vmap with an appropriate in_axes. With this in mind, you can re-express your computation this way:

def f_loops(xt, yt, zt, xs, ys, zs):
  a = []
  for _xt in xt:
    for _yt in yt:
      for _zt in zt:
        a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))
  return jnp.array(a)

def f_vmap(xt, yt, zt, xs, ys, zs):
  f_z = vmap(fn_prime, in_axes=(None, None, 0, None, None, None))
  f_yz = vmap(f_z, in_axes=(None, 0, None, None, None, None))
  f_xyz = vmap(f_yz, in_axes=(0, None, None, None, None, None))
  return jnp.stack(f_xyz(xt, yt, zt, xs, ys, zs), axis=3).reshape(64, 3, 3)

import numpy as np  # needed for np.testing below

out_loops = f_loops(xt, yt, zt, xs, ys, zs)
out_vmap = f_vmap(xt, yt, zt, xs, ys, zs)

np.testing.assert_allclose(out_loops, out_vmap)  # passes
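
The `reshape(64, 3, 3)` above is tied to the array sizes in the question. As a hedged variation (not the answer's original code), the sketch below infers the sizes from the inputs instead. The name `f_vmap_any_size` and the sample arrays are made up for illustration; `fn` and `fn_prime` are repeated from the question so the block runs on its own.

```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Gradient w.r.t. the target point, mapped over the source points.
fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

def f_vmap_any_size(xt, yt, zt, xs, ys, zs):
    # One vmap per loop: innermost loop (zt) first, outermost (xt) last.
    f_z = vmap(fn_prime, in_axes=(None, None, 0, None, None, None))
    f_yz = vmap(f_z, in_axes=(None, 0, None, None, None, None))
    f_xyz = vmap(f_yz, in_axes=(0, None, None, None, None, None))
    # Each of the 3 outputs has shape (nx, ny, nz, ns); stack them on axis 3.
    stacked = jnp.stack(f_xyz(xt, yt, zt, xs, ys, zs), axis=3)
    # Flatten the three target axes; -1 lets reshape infer nx * ny * nz.
    return stacked.reshape(-1, 3, xs.shape[0])

# Sample inputs with deliberately different lengths. No target point
# coincides with a source point, where the gradient of sqrt is not finite.
xt = jnp.array([0.5, 1.5])
yt = jnp.array([0., 1., 2.])
zt = jnp.array([5., 6., 7., 8.])
xs = jnp.array([1., 2.])
ys = jnp.array([3., 3.])
zs = jnp.array([1., 1.])

out = f_vmap_any_size(xt, yt, zt, xs, ys, zs)  # shape (2 * 3 * 4, 3, 2)
```

The row order matches the nested loops: the entry for `(xt[i], yt[j], zt[k])` lands at index `i * ny * nz + j * nz + k` along the first axis.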

Upvotes: 3
