Reputation: 21
I want to solve a second-order differential equation with a neural network, using the JAX library for automatic differentiation. An example I am following uses elementwise_grad to compute the first and second derivatives of my target variable 'u', i.e. du/dx and d2u/dx2. What is the alternative in JAX?
For example, 'u' is evaluated by the neural network function defined below:
'''
def u(params, inputs):
    for Weights, biases in params:
        outputs = np.dot(inputs, Weights) + biases
        inputs = sigmoid(outputs)
    return outputs
'''
u takes two arguments: params is the set of weights and biases, and inputs is the range of x values with respect to which u will be differentiated.
Suppose x has length 50, so the output u will also have size 50*1.
Now I need to differentiate all 50 values of u at once. Which JAX functions should I use to calculate du/dx and d2u/dx2? The grad function is not working:
dudx = grad(u,1)(x)
d2udx2 = grad(grad(u,1)(x))(x)
These are giving some errors
Upvotes: 2
Views: 1253
Reputation: 86513
This isn't really a function that has a meaningful element-wise gradient. It maps one vector space to another, and the appropriate derivative for this kind of operation is the Jacobian:
dudx = jax.jacobian(u, 1)(params, x)
The result is a matrix whose entries are the derivative of the ith output with respect to the jth input.
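As a rough sketch of what that Jacobian looks like for the network in the question: the layer sizes (1 -> 8 -> 1), the random initialisation, and the (50, 1) shape of x below are assumptions made only for illustration, since the question does not specify them. Under those assumptions each row of x is processed independently, so the only nonzero entries are those where the output index equals the input index, and those entries can be read off as du/dx. The same indexing applied to jax.hessian gives d2u/dx2.

```python
import jax
import jax.numpy as jnp

def u(params, inputs):
    # Same network as in the question, written with jax.numpy.
    for weights, biases in params:
        outputs = jnp.dot(inputs, weights) + biases
        inputs = jax.nn.sigmoid(outputs)
    return outputs

# Hypothetical layer sizes (1 -> 8 -> 1); not taken from the question.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [
    (jax.random.normal(k1, (1, 8)), jnp.zeros(8)),
    (jax.random.normal(k2, (8, 1)), jnp.zeros(1)),
]
# Assumed input layout: 50 samples, one feature each.
x = jnp.linspace(0.0, 1.0, 50).reshape(50, 1)

# Differentiate with respect to argument 1 (inputs).
# Result shape is output shape + input shape: (50, 1, 50, 1).
dudx_full = jax.jacobian(u, 1)(params, x)

# Rows are independent here, so keep only the i == j entries.
idx = jnp.arange(x.shape[0])
dudx = dudx_full[idx, 0, idx, 0]        # shape (50,)

# Second derivative: shape (50, 1, 50, 1, 50, 1) before indexing.
d2udx2_full = jax.hessian(u, 1)(params, x)
d2udx2 = d2udx2_full[idx, 0, idx, 0, idx, 0]  # shape (50,)
```

Building the full Jacobian and Hessian only to keep their diagonals is wasteful for large inputs, so this is meant to show the structure of the result rather than an efficient recipe.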
Note that if you had a truly element-wise function and wanted to compute the element-wise gradient, you could do so with vmap; for example:
def f(x):
    return jnp.exp(x) - 1
df_dx = jax.vmap(jax.grad(f))(x)
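A self-contained version of this snippet, extended to the second derivative by nesting jax.grad, might look like the following. The input array x is an assumed example (50 points on [0, 1]); everything else follows the lines above.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar in, scalar out, so jax.grad applies directly.
    return jnp.exp(x) - 1

# Assumed example input: 50 points on [0, 1].
x = jnp.linspace(0.0, 1.0, 50)

# vmap applies the scalar derivative to every element of x.
df_dx = jax.vmap(jax.grad(f))(x)

# Nesting grad gives the second derivative of the scalar function.
d2f_dx2 = jax.vmap(jax.grad(jax.grad(f)))(x)
```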
That doesn't work for your function, because the mapping to the output vector space is determined by the contents of params, and vmap cannot easily account for that.
Upvotes: 1