Reputation: 21
I am solving a PDE using a neural network. My neural network is as follows:
def f(params, inputs):
    for w, b in params:
        outputs = jnp.dot(inputs, w) + b
        inputs = jnn.swish(outputs)
    return outputs
The layer architecture of the network is [1,5,2], so I have one input neuron and two output neurons. Therefore, if I pass a batch of 10 inputs, I should get a (10,2) array as output. Let the two output neurons be called 'p' and 'q' respectively. How do I find dp/dx and dq/dx? I don't want to pick values out of Jacobians and Hessians; I would like something more explicit. What I mean is something like this:
p = lambda inputs: f(params, inputs)[:,0].reshape(-1,1)
q = lambda inputs: f(params, inputs)[:,1].reshape(-1,1)
p_x = lambda inputs: vmap(jacfwd(p,argnums=0))(inputs)
q_x = lambda inputs: vmap(jacfwd(q,argnums=0))(inputs)
k_p_x = lambda inputs: kappa(inputs).reshape(-1,1) * p_x(inputs)
##And other calculations proceed..
When I execute p(inputs) it's working as expected (as it should), but as soon as I execute p_x(inputs) I am getting an error: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
How do I get around this?
Upvotes: 1
Views: 403
Reputation: 86513
The reason you are seeing an index error is that your p function expects a two-dimensional input (it indexes the network output with [:,0]), and when you wrap it in vmap you are effectively passing the function a single one-dimensional row at a time. The output for one row is then one-dimensional, so the two-index lookup fails.
You can fix this by changing your function so that it accepts a one-dimensional input, and then using vmap as appropriate to compute the batched result.
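To make the shape issue concrete, here is a minimal sketch (the helper name record_shape is hypothetical, introduced only for illustration). It records the shape a function sees inside vmap, then shows that a one-dimensional array rejects two-index lookups such as [:, 0]:

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def record_shape(x):
    # Hypothetical helper: runs while vmap traces the function, so
    # x.shape is the per-example shape, not the batched shape.
    seen_shapes.append(x.shape)
    return x

inputs = jnp.arange(10.0).reshape(10, 1)
jax.vmap(record_shape)(inputs)
print(seen_shapes)  # [(1,)]

# A single row is one-dimensional, so indexing it with two indices fails.
row = inputs[0]  # shape (1,)
caught = False
try:
    row[:, 0]
except IndexError:
    caught = True
print(caught)  # True
```

The exact wording of the IndexError message may differ between JAX versions.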
Here is a complete example with the modified versions of your functions:
import jax
import jax.numpy as jnp
from jax import nn as jnn
from jax import vmap, jacfwd
def f(params, inputs):
    for w, b in params:
        # Include the bias term, as in the network from the question
        outputs = jnp.dot(inputs, w) + b
        inputs = jnn.swish(outputs)
    return outputs
# Some example inputs and parameters
inputs_x = jnp.ones((10, 1))
params = [
    (jnp.ones((1, 5)), 1),
    (jnp.ones((5, 2)), 1)
]
inputs = jnp.arange(10.0).reshape(10, 1)
# p and q map a length-1 input to a length-1 output
p = lambda inputs: f(params, inputs)[0].reshape(1)
q = lambda inputs: f(params, inputs)[1].reshape(1)
p_batched = vmap(p)
q_batched = vmap(q)
p_x = lambda inputs: vmap(jacfwd(p,argnums=0))(inputs)
q_x = lambda inputs: vmap(jacfwd(q,argnums=0))(inputs)
print(p_batched(inputs).shape)
# (10, 1)
print(q_batched(inputs).shape)
# (10, 1)
# Note: since p and q map a size 1 input to size-1 output,
# p_x and q_x compute a sequence of 10 1x1 jacobians.
print(p_x(inputs).shape)
# (10, 1, 1)
print(q_x(inputs).shape)
# (10, 1, 1)
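If the trailing 1x1 Jacobian axes are inconvenient, one possible alternative (a sketch, not the only approach) is to treat p and q as scalar-to-scalar functions and use jax.grad, which returns a plain scalar derivative per sample. The example parameters mirror the ones above; the names xs, dp_dx and dq_dx are chosen here for illustration:

```python
import jax
import jax.numpy as jnp
from jax import nn as jnn

def f(params, inputs):
    for w, b in params:
        outputs = jnp.dot(inputs, w) + b
        inputs = jnn.swish(outputs)
    return outputs

# Example parameters for a [1, 5, 2] network (illustrative values)
params = [(jnp.ones((1, 5)), 1.0), (jnp.ones((5, 2)), 1.0)]
inputs = jnp.arange(10.0).reshape(10, 1)

# p and q map a scalar x to a scalar output
p = lambda x: f(params, x.reshape(1))[0]
q = lambda x: f(params, x.reshape(1))[1]

# grad of a scalar function is a scalar; vmap batches it over x
p_x = jax.vmap(jax.grad(p))
q_x = jax.vmap(jax.grad(q))

xs = inputs[:, 0]                # shape (10,)
dp_dx = p_x(xs).reshape(-1, 1)   # shape (10, 1)
dq_dx = q_x(xs).reshape(-1, 1)   # shape (10, 1)
print(dp_dx.shape, dq_dx.shape)
```

With the derivatives in (10, 1) form, they can be multiplied elementwise with other (10, 1) quantities without any extra squeezing.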
Upvotes: 1