Mahdi

Reputation: 11

Failing to return gradients

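This code is supposed to calculate the gradients of a network's output with respect to its inputs, but it seems to return wrong values. What is the problem with the code? For more context, the matrix "B" built inside A is supposed to be diagonal, with the following values:

$$B = \operatorname{diag}\left(0,\ \frac{\partial^2 c}{\partial x^2} + \frac{\partial^2 (cu)}{\partial x\,\partial y},\ \frac{\partial^2 c}{\partial x\,\partial y} + \frac{\partial^2 (cu)}{\partial y^2}\right)$$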

import jax.numpy as jnp
from jax import jacfwd, jacrev

class NCL(object):
    def __init__(self, network, mass_constant=2):
        self.network = network
        self.mc = mass_constant

    def A(self, x, params):
            u_v = self.network(x,params[0])[:-1]
            c, cu = self.network(x, params[0])[:2]  # Extract c and cu
            dc_dx = jacfwd(lambda x: c)(x)
            print(dc_dx)
            d2c_dx2 = jacfwd(lambda x: dc_dx)(x)[1]
            d2c_dxdy = jacfwd(lambda x: dc_dx)(x)[2]
            dcu_dy = jacfwd(lambda x: cu)(x)[2]
            d2cu_dy2 = jacfwd(lambda x: dcu_dy)(x)[2]
            d2cu_dydx = jacfwd(lambda x: dcu_dy)(x)[1]
            
            I_2 = d2c_dx2 + d2cu_dydx
            I_3 = d2c_dxdy + d2cu_dy2


            I = [[0.0],[I_2],[I_3]]
            print("d",dc_dx)
            N = len(x)
            B = jnp.zeros((N,N))
            diag_idx = jnp.diag_indices(N,1)
            B = B.at[diag_idx].set(I)

           
            A = jnp.zeros((N,N))
            idx = jnp.triu_indices(N,1)
            A = A.at[idx].set(u_v)
            return A - A.T + B

# Example usage
network = lambda x, params: jnp.array([x[0]*x[1]*x[1] + params[0], x[1]*x[1]*x[1] * params[1], x[2] * params[2],x[2]])

ncl_instance = NCL(network)

# Example input
x_input = jnp.array([1.0, 2.0, 3.0])

# Example parameters
params_input = jnp.array([0.1, 0.2, 0.3])

# Test the A function
result_A = ncl_instance.A(x_input, [params_input])

print("Result A:\n", result_A)


Answers (1)

jakevdp

Reputation: 86513

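The problem seems to be in your I array. Because I_2 and I_3 each have shape (3,), the list [[0.0],[I_2],[I_3]] mixes an entry of shape (1,) with entries of shape (1, 3); in other words, you are attempting to construct a ragged array, which JAX doesn't support. Perhaps this is what you intended?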

I = jnp.array([jnp.zeros_like(I_2), I_2, I_3])
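To see the shape problem in isolation, here is a minimal sketch that uses zero vectors of shape (3,) as stand-ins for I_2 and I_3 (an assumption based on the shapes they end up with in the question's code). The original nested list is rejected, while the suggested construction stacks three equal-length rows into a (3, 3) array. The exact exception type may differ between JAX versions, so the sketch catches both ValueError and TypeError.

```python
import jax.numpy as jnp

# Stand-ins: in the question's code, I_2 and I_3 are vectors of shape (3,).
I_2 = jnp.zeros(3)
I_3 = jnp.zeros(3)

# Original construction: [0.0] has shape (1,), while [I_2] and [I_3]
# have shape (1, 3), so the nested list is ragged and JAX rejects it.
try:
    jnp.array([[0.0], [I_2], [I_3]])
    ragged_accepted = True
except (ValueError, TypeError) as err:
    ragged_accepted = False
    print("ragged list rejected:", type(err).__name__)

# Suggested construction: three rows of shape (3,) stack into a (3, 3) array.
I = jnp.array([jnp.zeros_like(I_2), I_2, I_3])
print(I.shape)  # (3, 3)
```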

With that change, your code prints this:

[0. 0. 0.]
d [0. 0. 0.]
Result A:
 [[ 0.          4.1         1.6       ]
 [-4.1         0.          0.90000004]
 [-1.6        -0.90000004  0.        ]]
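Every derivative in this output is zero: dc_dx prints as [0. 0. 0.] and the diagonal of the result is all zeros. This happens because lambda x: c returns a value that was computed before differentiation, so the lambda does not depend on its argument and its derivative is identically zero. The sketch below differentiates functions that recompute c and cu from their input and reads the second derivatives off jax.hessian. It assumes the toy network from the question and takes "x" and "y" to mean x[1] and x[2], as the indexing in the question suggests; ncl_matrix is a hypothetical helper name, not part of the original code.

```python
import jax
import jax.numpy as jnp


def network(x, params):
    # Toy network from the question; outputs [c, cu, ..., last].
    return jnp.array([x[0] * x[1] * x[1] + params[0],
                      x[1] * x[1] * x[1] * params[1],
                      x[2] * params[2],
                      x[2]])


def ncl_matrix(x, params):
    """Hypothetical helper: returns A - A.T + B for the toy network."""
    # Scalar functions of the input, so JAX can trace x through them.
    c_fn = lambda z: network(z, params)[0]
    cu_fn = lambda z: network(z, params)[1]
    H_c = jax.hessian(c_fn)(x)    # H_c[i, j] = d^2 c / (dz_i dz_j)
    H_cu = jax.hessian(cu_fn)(x)  # H_cu[i, j] = d^2 (cu) / (dz_i dz_j)

    # Assumption: "x" is z[1] and "y" is z[2], following the question's indexing.
    I_2 = H_c[1, 1] + H_cu[1, 2]  # d^2c/dx^2 + d^2(cu)/dxdy
    I_3 = H_c[1, 2] + H_cu[2, 2]  # d^2c/dxdy + d^2(cu)/dy^2
    B = jnp.diag(jnp.array([0.0, I_2, I_3]))

    # Antisymmetric part, built as in the question.
    u_v = network(x, params)[:-1]
    N = x.shape[0]
    A = jnp.zeros((N, N)).at[jnp.triu_indices(N, 1)].set(u_v)
    return A - A.T + B


x = jnp.array([1.0, 2.0, 3.0])
params = jnp.array([0.1, 0.2, 0.3])

# Pitfall from the question: c is already a concrete value, so this lambda
# ignores its argument and its derivative is identically zero.
c = network(x, params)[0]
zero_grad = jax.jacfwd(lambda z: c)(x)

# Differentiating a function that recomputes c from its input works.
grad_c = jax.grad(lambda z: network(z, params)[0])(x)

print(zero_grad)  # [0. 0. 0.]
print(grad_c)     # [4. 4. 0.]
print(ncl_matrix(x, params))
```

With the question's inputs, the gradient of c is [4, 4, 0], and the middle diagonal entry of B becomes 2.0 (since the second derivative of c with respect to x[1] is 2·x[0]); the remaining terms vanish for this particular toy network. The sketch builds B with jnp.diag rather than jnp.diag_indices(N, 1): with ndim=1, diag_indices returns a single index array, so .at[...] selects whole rows of B instead of its diagonal.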

