Reputation: 29
This may be a very simple question, but I was wondering how to perform the mapping in the following example.
Suppose we have a function whose derivative we want to evaluate with respect to `xt`, `yt` and `zt`, but which also takes the additional parameters `xs`, `ys` and `zs`.
```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)
```
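For reference, `grad` with a tuple `argnums` returns a tuple holding one partial derivative per listed argument. A minimal sketch for a single target/source pair (the point values are arbitrary, chosen only for illustration) checks this against the analytical gradient of the distance, (t - s) / |t - s|:

```python
import jax.numpy as jnp
from jax import grad

def fn(xt, yt, zt, xs, ys, zs):
    # Euclidean distance between target (xt, yt, zt) and source (xs, ys, zs)
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Differentiate with respect to the three target coordinates only
grad_t = grad(fn, argnums=(0, 1, 2))

# One target (4, 2, 3) and one source (1, 3, 1): illustrative values
g = grad_t(4., 2., 3., 1., 3., 1.)

# Analytical gradient of the distance: (t - s) / |t - s|
r = jnp.sqrt((4. - 1.) ** 2 + (2. - 3.) ** 2 + (3. - 1.) ** 2)
expected = ((4. - 1.) / r, (2. - 3.) / r, (3. - 1.) / r)
print(g)  # a tuple of three scalar arrays
```

The gradient is undefined where a target coincides with a source (`sqrt` is then differentiated at 0) and JAX returns `nan` there; with the data below, the target (1, 3, 1) coincides with the first source.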
Now, let us define the input data:
```python
xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])
```
To evaluate the gradient for every combination of the points in `xt`, `yt` and `zt` (vectorized over the source points `xs`, `ys`, `zs`), I have to do the following:
```python
# Gradient w.r.t. (xt, yt, zt), vectorized over the source points (xs, ys, zs)
fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

a = []
for _xt in xt:
    for _yt in yt:
        for _zt in zt:
            a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))
```
This produces a list of tuples. Once the list is converted to a `jnp.array`, it has the following shape:

```python
a = jnp.array(a)
print(f'shape = {a.shape}')
```

```
shape = (64, 3, 3)
```
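Row `i` of `a` corresponds to one (`xt`, `yt`, `zt`) triple, with `xt` varying slowest and `zt` fastest, because that is the nesting order of the loops. A minimal sketch (on index grids only, no gradients) confirming that this order matches a C-order flattening of an `indexing="ij"` grid, i.e. `i = 16 * ix + 4 * iy + iz` for four points per axis:

```python
import jax.numpy as jnp

n = 4  # points per axis, as in xt, yt, zt

# Index triples in the same nesting order as the triple for-loop
loop_order = [[ix, iy, iz] for ix in range(n) for iy in range(n) for iz in range(n)]

# The same triples as an 'ij'-indexed grid, flattened in C (row-major) order
ix, iy, iz = jnp.meshgrid(jnp.arange(n), jnp.arange(n), jnp.arange(n), indexing="ij")
flat = jnp.stack([ix, iy, iz], axis=-1).reshape(-1, 3)

print(flat.shape)  # (64, 3)
```

So any grid-shaped result whose leading axes are `(4, 4, 4)` in (`xt`, `yt`, `zt`) order can be reshaped to 64 rows without changing the loop's ordering.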
My question is: is there a way to avoid these for loops and evaluate all the gradients in a single sweep?
Upvotes: 3
Reputation: 141
You could vectorize the gradient with `jnp.vectorize` and pass grid-like (broadcastable) inputs for the arguments being differentiated:
```python
def f_vect(xt, yt, zt, xs, ys, zs):
    # Broadcast the scalar gradient over array-valued xt, yt, zt
    fp = jnp.vectorize(grad(fn, argnums=(0, 1, 2)))
    # Map over the source points, as in the question
    fn_prime = vmap(fp, in_axes=(None, None, None, 0, 0, 0))
    a = jnp.array(fn_prime(xt[:, None, None],
                           yt[None, :, None],
                           zt[None, None, :],
                           xs, ys, zs))  # shape (3, 3, 4, 4, 4)
    a = a.reshape(a.shape[0], a.shape[1], -1)  # shape (3, 3, 64)
    a = jnp.rollaxis(a, -1)  # shape (64, 3, 3)
    return a
```
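To see what `jnp.vectorize` contributes here, a minimal sketch with a single source point placed at the origin (an arbitrary choice that keeps it away from every target, so no `nan` appears): the open-grid inputs of shapes `(4, 1, 1)`, `(1, 4, 1)` and `(1, 1, 4)` broadcast against each other, so each partial derivative comes back as a full `(4, 4, 4)` array.

```python
import jax.numpy as jnp
from jax import grad

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Scalar gradient w.r.t. the target coordinates, lifted to broadcast over arrays
fp = jnp.vectorize(grad(fn, argnums=(0, 1, 2)))

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
src = jnp.zeros(3)  # one source point at the origin (illustrative assumption)

# (4, 1, 1), (1, 4, 1) and (1, 1, 4) broadcast to a (4, 4, 4) grid
dx, dy, dz = fp(xt[:, None, None], yt[None, :, None], zt[None, None, :],
                src[0], src[1], src[2])
print(dx.shape)  # (4, 4, 4)
```

In the answer's code, the outer `vmap` over the sources then adds a leading axis of length 3 to each of these arrays.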
Upvotes: 0
Reputation: 86310
A good rule of thumb for cases like this is that each nested `for` loop translates to a nested `vmap` over an appropriate `in_axes`. With this in mind, you can re-express your computation this way:
```python
import numpy as np

def f_loops(xt, yt, zt, xs, ys, zs):
    a = []
    for _xt in xt:
        for _yt in yt:
            for _zt in zt:
                a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))
    return jnp.array(a)

def f_vmap(xt, yt, zt, xs, ys, zs):
    # The innermost loop (over zt) becomes the innermost vmap
    f_z = vmap(fn_prime, in_axes=(None, None, 0, None, None, None))
    f_yz = vmap(f_z, in_axes=(None, 0, None, None, None, None))
    f_xyz = vmap(f_yz, in_axes=(0, None, None, None, None, None))
    # f_xyz returns a tuple of three (4, 4, 4, 3) arrays; stack to (4, 4, 4, 3, 3)
    return jnp.stack(f_xyz(xt, yt, zt, xs, ys, zs), axis=3).reshape(64, 3, 3)

out_loops = f_loops(xt, yt, zt, xs, ys, zs)
out_vmap = f_vmap(xt, yt, zt, xs, ys, zs)
np.testing.assert_allclose(out_loops, out_vmap)  # passes
```
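The rule of thumb can be seen on a toy problem. A minimal sketch with a hypothetical scalar function `f` of two arguments: a double loop (outer over `a`, inner over `b`) becomes two nested `vmap`s, where the inner `vmap` maps over the inner loop variable and the outer `vmap` over the outer one, so the output axes come out in loop order.

```python
import jax.numpy as jnp
from jax import vmap

def f(a, b):
    # Hypothetical scalar function of two scalars, for illustration only
    return a * 10. + b

a = jnp.array([1., 2., 3.])
b = jnp.array([0., 1.])

# Nested loops: outer over a, inner over b
out_loops = jnp.array([[f(ai, bi) for bi in b] for ai in a])

# Nested vmaps: the inner loop becomes the inner vmap
f_b = vmap(f, in_axes=(None, 0))     # maps over b with a fixed
f_ab = vmap(f_b, in_axes=(0, None))  # maps over a with b fixed
out_vmap = f_ab(a, b)
print(out_vmap.shape)  # (3, 2)
```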
Upvotes: 3