Sumanta Roy

Reputation: 21

How to compute batch-wise Jacobians using vmap in JAX?

I want to solve a 2D differential equation using a neural network, and I am working with the JAX library. The neural network function I am using approximates the function u = f(x,y) and goes something like this:

import jax.numpy as jnp
from jax import nn as jnn

def f(params, inputs_x, inputs_y):
  inputs = jnp.concatenate((inputs_x, inputs_y), axis=1)
  for w, b in params:
    outputs = jnp.dot(inputs, w)
    inputs = jnn.swish(outputs)
  return outputs

params is a PyTree that contains the weight and bias matrices. For the 2D problem, let's take the layer sizes to be something like [2,5,1]. There are 10 batches of (inputs_x, inputs_y) passed to the function, so inputs_x and inputs_y both have shape (10,1). Therefore, the output I want should also have shape (10,1). The real problem comes when I try to find du/dx, du/dy, d2u/dx2, or d2u/dy2. I am writing something like this:

from jax import vmap, jacfwd

u = lambda x, y: f(params, x, y)
u_x = lambda x, y: vmap(jacfwd(u, argnums=0), in_axes=(0, 0))(x, y)
u_xx = lambda x, y: vmap(jacfwd(u_x, argnums=0), in_axes=(0, 0))(x, y)

I am getting errors.

If I were solving a 1D differential equation, everything would work fine. In that case, the neural network function is something like this:

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w)
    inputs = jnn.swish(outputs)
  return outputs

u = lambda x: f(params, x)
u_x = lambda x: vmap(jacfwd(u, argnums=0))(x)

The layer sizes are [1,5,1], I pass 10 batches of inputs into the neural network function, and I compute the gradients using vmap. Everything works fine!
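
For reference, here is a minimal runnable sketch of this 1D setup; the parameter values below are placeholders I made up, and the biases are unpacked but unused, matching f above:

import jax.numpy as jnp
from jax import nn as jnn
from jax import vmap, jacfwd

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w)
    inputs = jnn.swish(outputs)
  return outputs

# Placeholder parameters for layer sizes [1, 5, 1]
params = [
    (jnp.ones((1, 5)), jnp.zeros(5)),
    (jnp.ones((5, 1)), jnp.zeros(1)),
]

x = jnp.ones((10, 1))  # 10 batched scalar inputs

u = lambda x: f(params, x)
u_x = lambda x: vmap(jacfwd(u, argnums=0))(x)

print(u(x).shape)    # (10, 1)
print(u_x(x).shape)  # (10, 1, 1) -- one (1, 1) Jacobian per example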

As soon as I have a 2D problem with two input neurons, the layer sizes become [2,5,1], and when I pass 10 batches of inputs for both x and y together, vmap doesn't work anymore. I want to find du/dx, du/dy, d2u/dx2, and d2u/dy2 using the neural network and the functions above, and I expect each of them to return a result of shape (10,1), but instead I am getting errors.

Upvotes: 1

Views: 694

Answers (1)

jakevdp

Reputation: 86443

It looks like your function is not compatible with vmap because it expects explicit batch dimensions: inside vmap each example has shape (1,), so concatenating along axis=1 fails. You can fix this by concatenating along axis=-1 rather than axis=1, which works for both batched and unbatched inputs. Then your function calls could look something like the following:

from functools import partial
import jax
import jax.numpy as jnp
from jax import nn as jnn

def f(params, inputs_x, inputs_y):
  inputs = jnp.concatenate((inputs_x, inputs_y), axis=-1)
  for w, b in params:
    outputs = jnp.dot(inputs, w)
    inputs = jnn.swish(outputs)
  return outputs

# Some example inputs and parameters
inputs_x = jnp.ones((10, 1))
inputs_y = jnp.ones((10, 1))
params = [
    (jnp.ones((2, 5)), 1),
    (jnp.ones((5, 1)), 1)
]

u = partial(f, params)

# u: (10,1)->(10,1)
print(u(inputs_x, inputs_y).shape)
# (10, 1)

# u: (1)->(1) batched to (10,1)->(10,1)
print(jax.vmap(u)(inputs_x, inputs_y).shape)
# (10, 1)

# ∇u: (1) -> (1,1) batched to (10,1)->(10,1,1)
print(jax.vmap(jax.jacobian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1)

# ∇²u: (1) -> (1,1,1) batched to (10,1)->(10,1,1,1)
print(jax.vmap(jax.hessian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1, 1)

Upvotes: 0
