Sumanta Roy

Reputation: 21

How to take a derivative of one of the outputs of a neural network (involving batched inputs) with respect to inputs?

I am solving a PDE using a neural network. My neural network is as follows:

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

The layer architecture of the network is [1,5,2]. Hence, I have one input neuron and two output neurons, so if I pass a batch of 10 inputs, I should get a (10,2) array as output. Now let the output neurons be termed 'p' and 'q' respectively. How do I find dp/dx and dq/dx? I don't want to pick values out of full Jacobians and Hessians; I want more explicit functionality, something like this:

p = lambda inputs: f(params, inputs)[:,0].reshape(-1,1)
q = lambda inputs: f(params, inputs)[:,1].reshape(-1,1)

p_x = lambda inputs: vmap(jacfwd(p,argnums=0))(inputs)
q_x = lambda inputs: vmap(jacfwd(q,argnums=0))(inputs)

k_p_x = lambda inputs: kappa(inputs).reshape(-1,1) * p_x(inputs)
        ##And other calculations proceed..

When I execute p(inputs) it works as expected, but as soon as I execute p_x(inputs) I get an error: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

How do I get around this?

Upvotes: 1

Views: 403

Answers (1)

jakevdp

Reputation: 86513

The reason you are seeing an IndexError is that your p function expects a two-dimensional input, but when you wrap it in vmap, the function effectively receives a single one-dimensional row at a time.

You can fix this by changing your function so that it accepts a one-dimensional input, and then use vmap as appropriate to compute the batched result.
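To see concretely what vmap does to the shapes, here is a minimal sketch with a toy function (not from the original post):

```python
import jax
import jax.numpy as jnp

batched = jnp.ones((10, 1))

# Under vmap the function sees one row of shape (1,) at a time,
# so two-dimensional indexing such as x[:, 0] raises an IndexError:
f_2d = lambda x: x[:, 0]
# jax.vmap(f_2d)(batched)  # IndexError: too many indices

# One-dimensional indexing matches the per-row shape and works:
f_1d = lambda x: x[0]
print(jax.vmap(f_1d)(batched).shape)  # (10,)
```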

Here is a complete example with the modified versions of your functions:

import jax
import jax.numpy as jnp
from jax import nn as jnn
from jax import vmap, jacfwd

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

# Some example parameters and inputs
params = [
    (jnp.ones((1, 5)), 1.0),
    (jnp.ones((5, 2)), 1.0),
]

inputs = jnp.arange(10.0).reshape(10, 1)

# p and q map a length-1 input to a length-1 output
p = lambda inputs: f(params, inputs)[0].reshape(1)
q = lambda inputs: f(params, inputs)[1].reshape(1)

p_batched = vmap(p)
q_batched = vmap(q)

p_x = lambda inputs: vmap(jacfwd(p, argnums=0))(inputs)
q_x = lambda inputs: vmap(jacfwd(q, argnums=0))(inputs)

print(p_batched(inputs).shape)
# (10, 1)
print(q_batched(inputs).shape)
# (10, 1)

# Note: since p and q map a size 1 input to size-1 output,
# p_x and q_x compute a sequence of 10 1x1 jacobians.
print(p_x(inputs).shape)
# (10, 1, 1)
print(q_x(inputs).shape)
# (10, 1, 1)
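If the trailing 1x1 Jacobian dimensions feel indirect, an alternative sketch (reusing the same f and params, but with jax.grad in place of jacfwd) is to make p and q scalar-valued per sample, which yields dp/dx and dq/dx directly as (10, 1) arrays:

```python
import jax
import jax.numpy as jnp
from jax import nn as jnn

def f(params, inputs):
  for w, b in params:
    outputs = jnp.dot(inputs, w) + b
    inputs = jnn.swish(outputs)
  return outputs

params = [
    (jnp.ones((1, 5)), 1.0),
    (jnp.ones((5, 2)), 1.0),
]
inputs = jnp.arange(10.0).reshape(10, 1)

# Each scalar function maps a length-1 input to a scalar,
# which is exactly what jax.grad requires.
p_scalar = lambda x: f(params, x)[0]
q_scalar = lambda x: f(params, x)[1]

# vmap(grad(...)) gives one derivative per sample, shape (10, 1),
# which lines up directly with e.g. kappa(inputs).reshape(-1, 1).
dp_dx = jax.vmap(jax.grad(p_scalar))(inputs)
dq_dx = jax.vmap(jax.grad(q_scalar))(inputs)
print(dp_dx.shape)  # (10, 1)
```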

Upvotes: 1
