Reputation: 21
I want to solve a 2D differential equation with a neural network, using the JAX library. The network approximates the function u = f(x,y) and looks something like this:
def f(params, inputs_x, inputs_y):
    inputs = jnp.concatenate((inputs_x, inputs_y), axis=1)
    for w, b in params:
        outputs = jnp.dot(inputs, w)
        inputs = jnn.swish(outputs)
    return outputs
params is a PyTree that contains the weight and bias matrices. For the 2D problem, take layer sizes such as [2,5,1]. A batch of 10 (x_inputs, y_inputs) pairs is passed to the function, so inputs_x and inputs_y both have shape (10,1), and the output I want should also have shape (10,1). The real problem comes when I try to compute du/dx, du/dy, d2u/dx2 or d2u/dy2. I am writing something like this:
u = lambda x,y: f(params, x, y)
u = lambda x,y: f(params, x)
u_x = lambda x,y: vmap(jacfwd(u,argnums=0), in_axes=(0,0))(x,y)
u_xx = lambda x,y: vmap(jacfwd(u_x,argnums=0), in_axes=(0,0))(x,y)
I am getting errors.
When I solve a 1D differential equation, everything works fine. In that case, the neural network function is something like this:
def f(params, inputs):
    for w, b in params:
        outputs = jnp.dot(inputs, w)
        inputs = jnn.swish(outputs)
    return outputs
u = lambda x: f(params, x)
u_x = lambda x: vmap(jacfwd(u,argnums=0))(x)
The layer sizes are [1,5,1], and I pass a batch of 10 inputs to the neural network function and compute the gradients using vmap. Everything works fine!
As soon as I have a 2D problem with two input neurons, the layer sizes become [2,5,1], and when I pass a batch of 10 inputs for both x and y together, vmap doesn't work anymore. I want to compute du/dx, du/dy, d2u/dx2 and d2u/dy2 from the neural network, and I expect each of the four functions to return a result of shape (10,1), but I am getting an error.
Upvotes: 1
Views: 694
Reputation: 86443
It looks like your function is not compatible with vmap, because it expects explicit batch dimensions. You can fix this by concatenating along axis=-1 rather than axis=1. Then your function calls could look something like the following:
from functools import partial
import jax
import jax.numpy as jnp
from jax import nn as jnn
def f(params, inputs_x, inputs_y):
    inputs = jnp.concatenate((inputs_x, inputs_y), axis=-1)
    for w, b in params:
        outputs = jnp.dot(inputs, w)
        inputs = jnn.swish(outputs)
    return outputs
# Some example inputs and parameters
inputs_x = jnp.ones((10, 1))
inputs_y = jnp.ones((10, 1))
params = [
    (jnp.ones((2, 5)), 1),
    (jnp.ones((5, 1)), 1)
]
u = partial(f, params)
# u: (10,1)->(10,1)
print(u(inputs_x, inputs_y).shape)
# (10, 1)
# u: (1)->(1) batched to (10,1)->(10,1)
print(jax.vmap(u)(inputs_x, inputs_y).shape)
# (10, 1)
# ∇u: (1) -> (1,1) batched to (10,1)->(10,1,1)
print(jax.vmap(jax.jacobian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1)
# ∇²u: (1) -> (1,1,1) batched to (10,1)->(10,1,1,1)
print(jax.vmap(jax.hessian(u))(inputs_x, inputs_y).shape)
# (10, 1, 1, 1)
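The calls above differentiate only with respect to the first argument (x), and the results carry extra singleton axes. If you want all four quantities from the question (du/dx, du/dy, d2u/dx2, d2u/dy2), each with shape (10, 1), one possible approach is to select the argument with argnums and reshape the batched result. This is only a sketch: the helper names first_derivative and second_derivative are made up for illustration, and the all-ones parameters are the same placeholder values used above.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import nn as jnn


def f(params, inputs_x, inputs_y):
    # Concatenate along the last axis so the function also works on a
    # single sample of shape (1,), which is what vmap passes in.
    inputs = jnp.concatenate((inputs_x, inputs_y), axis=-1)
    for w, b in params:
        outputs = jnp.dot(inputs, w)
        inputs = jnn.swish(outputs)
    return outputs


# Placeholder parameters and inputs for layer sizes [2, 5, 1]
params = [(jnp.ones((2, 5)), 1), (jnp.ones((5, 1)), 1)]
inputs_x = jnp.ones((10, 1))
inputs_y = jnp.ones((10, 1))

u = partial(f, params)


def first_derivative(fun, argnum):
    # Hypothetical helper: per-sample Jacobian with respect to one argument,
    # batched over axis 0, then flattened from (N, 1, 1) to (N, 1).
    jac = jax.vmap(jax.jacfwd(fun, argnums=argnum))
    return lambda x, y: jac(x, y).reshape(x.shape[0], -1)


def second_derivative(fun, argnum):
    # Hypothetical helper: apply jacfwd twice to the same argument,
    # then flatten from (N, 1, 1, 1) to (N, 1).
    d2 = jax.jacfwd(jax.jacfwd(fun, argnums=argnum), argnums=argnum)
    hess = jax.vmap(d2)
    return lambda x, y: hess(x, y).reshape(x.shape[0], -1)


u_x = first_derivative(u, 0)   # du/dx
u_y = first_derivative(u, 1)   # du/dy
u_xx = second_derivative(u, 0)  # d2u/dx2
u_yy = second_derivative(u, 1)  # d2u/dy2

for name, fn in [("u_x", u_x), ("u_y", u_y), ("u_xx", u_xx), ("u_yy", u_yy)]:
    print(name, fn(inputs_x, inputs_y).shape)
```

The reshape assumes one input feature per argument and a single network output, as in the [2,5,1] setup; with more outputs you would want to keep the Jacobian axes instead of flattening them.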
Upvotes: 0