Antareep

Reputation: 21

What is the JAX alternative to autograd's elementwise_grad?

I want to solve a second-order differential equation with a neural network. For automatic differentiation I am using the JAX library. To compute the first- and second-order derivatives of my target variable 'u', i.e. du/dx and d2u/dx2, an example I am following uses elementwise_grad. What is its alternative in JAX?

For example, the neural network function evaluating 'u' is defined as below:

'''

import jax.numpy as np
from jax.nn import sigmoid  # or a user-defined sigmoid

def u(params, inputs):
    # Forward pass: each layer applies a dense transform followed by a sigmoid;
    # the final layer's pre-activation 'outputs' is returned.
    for Weights, biases in params:
        outputs = np.dot(inputs, Weights) + biases
        inputs = sigmoid(outputs)
    return outputs

'''

u takes two arguments: params, the set of weights and biases, and inputs, the x values with respect to which u will be differentiated.

Suppose x has a length of 50; then the output u will also have shape 50×1.

Now I have to differentiate all 50 values of u at once. Which JAX functions should I use to calculate du/dx and d2u/dx2? The grad function is not working:

dudx = grad(u,1)(x)
d2udx2 = grad(grad(u,1)(x))(x) 

These give errors.

Upvotes: 2

Views: 1253

Answers (1)

jakevdp

Reputation: 86513

This isn't really a function that has a meaningful elementwise gradient. It maps one vector space to another, and the appropriate derivative for this kind of operation is a Jacobian:

dudx = jax.jacobian(u, 1)(params, x)

The result is a matrix whose (i, j) entry is the derivative of the ith output with respect to the jth input.
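
For the second derivative, a minimal sketch along the same lines (assuming u, params, and x are defined as in the question) is to take the Jacobian of the Jacobian, whose entries are the second partial derivatives of each output with respect to each pair of inputs:

# Sketch: d2u/dx2 as a composed Jacobian, differentiating twice with respect to x
d2udx2 = jax.jacobian(jax.jacobian(u, 1), 1)(params, x)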

Note that if you had a truly element-wise function and wanted to compute the element-wise gradient, you could do so with vmap; for example:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.exp(x) - 1

# Map the scalar gradient of f over every element of x
df_dx = jax.vmap(jax.grad(f))(x)
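
As a quick usage sketch (xs here is a hypothetical array of sample points), the result matches exp at every point:

xs = jnp.linspace(0.0, 1.0, 5)     # hypothetical sample points
print(jax.vmap(jax.grad(f))(xs))   # equals jnp.exp(xs) elementwise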

That doesn't work for your function, because the mapping to the output vector space is determined by the contents of params, and vmap cannot easily account for that.

Upvotes: 1
