Reputation: 11
This code is supposed to calculate the gradients of a network's output w.r.t. its inputs, but it seems to return wrong values. What is the problem with the code? For more context, the matrix "B" is supposed to be diagonal, with the following values: [0 0 0 0 d^2c/dx^2 d^2(cu)/dxdy 0 0 d^2c/dxdy d^2(cu)/dy^2]
import jax.numpy as jnp
from jax import jacfwd, jacrev

class NCL(object):
    def __init__(self, network, mass_constant=2):
        self.network = network
        self.mc = mass_constant

    def A(self, x, params):
        u_v = self.network(x, params[0])[:-1]
        c, cu = self.network(x, params[0])[:2]  # Extract c and cu
        dc_dx = jacfwd(lambda x: c)(x)
        print(dc_dx)
        d2c_dx2 = jacfwd(lambda x: dc_dx)(x)[1]
        d2c_dxdy = jacfwd(lambda x: dc_dx)(x)[2]
        dcu_dy = jacfwd(lambda x: cu)(x)[2]
        d2cu_dy2 = jacfwd(lambda x: dcu_dy)(x)[2]
        d2cu_dydx = jacfwd(lambda x: dcu_dy)(x)[1]
        I_2 = d2c_dx2 + d2cu_dydx
        I_3 = d2c_dxdy + d2cu_dy2
        I = [[0.0], [I_2], [I_3]]
        print("d", dc_dx)
        N = len(x)
        B = jnp.zeros((N, N))
        diag_idx = jnp.diag_indices(N, 1)
        B = B.at[diag_idx].set(I)
        A = jnp.zeros((N, N))
        idx = jnp.triu_indices(N, 1)
        A = A.at[idx].set(u_v)
        return A - A.T + B
# Example usage
network = lambda x, params: jnp.array([x[0]*x[1]*x[1] + params[0], x[1]*x[1]*x[1] * params[1], x[2] * params[2],x[2]])
ncl_instance = NCL(network)
# Example input
x_input = jnp.array([1.0, 2.0, 3.0])
# Example parameters
params_input = jnp.array([0.1, 0.2, 0.3])
# Test the A function
result_A = ncl_instance.A(x_input, [params_input])
print("Result A:\n", result_A)
Upvotes: 1
Views: 47
Reputation: 86513
The problem seems to be in your I array. You seem to be attempting to construct a ragged array, which JAX doesn't support. Perhaps this is what you were intending?
I = jnp.array([jnp.zeros_like(I_2), I_2, I_3])
With that change, your code prints this:
[0. 0. 0.]
d [0. 0. 0.]
Result A:
 [[ 0.          4.1         1.6       ]
 [-4.1         0.          0.90000004]
 [-1.6        -0.90000004  0.        ]]
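To see the difference concretely, here is a minimal sketch (not your original code: I_2 and I_3 are replaced with placeholder length-3 zero arrays) showing that the ragged list cannot be converted into a JAX array, while equal-shape rows stack cleanly:
import jax.numpy as jnp

# Placeholders standing in for I_2 and I_3 from the question's code.
I_2 = jnp.zeros(3)
I_3 = jnp.zeros(3)

# Ragged: a length-1 entry next to two length-3 rows has no rectangular shape,
# so JAX cannot build a single array from it.
try:
    jnp.array([[0.0], [I_2], [I_3]])
except (TypeError, ValueError) as err:
    print("ragged construction fails:", err)

# Rectangular: three rows of identical shape stack into a (3, 3) array,
# matching the fix above.
I = jnp.array([jnp.zeros_like(I_2), I_2, I_3])
print(I.shape)  # (3, 3)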
Upvotes: 0