David

Reputation: 1267

Obtaining zeros for this derivative in JAX

While implementing the Jacobian of the polar-to-Cartesian coordinate transformation, I obtain an array of zeros in JAX, which can't be right:

import jax
import jax.numpy as jnp
import numpy as np

theta = np.pi/4
r = 4.0

var = np.array([r, theta])

x = var[0]*jnp.cos(var[1])
y = var[0]*jnp.sin(var[1])

def f(var):
    return np.array([x, y])

jac = jax.jacobian(f)(var)
jac

#DeviceArray([[0., 0.],
#             [0., 0.]], dtype=float32)

What am I missing?
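For reference, with x = r cos θ and y = r sin θ, the Jacobian with respect to (r, θ) is not zero anywhere with r > 0; its determinant is r. At r = 4 and θ = π/4 it evaluates to:

$$
J = \frac{\partial(x, y)}{\partial(r, \theta)}
  = \begin{pmatrix} \cos\theta & -r\sin\theta \\ \sin\theta & r\cos\theta \end{pmatrix}
  = \begin{pmatrix} \tfrac{\sqrt{2}}{2} & -2\sqrt{2} \\ \tfrac{\sqrt{2}}{2} & 2\sqrt{2} \end{pmatrix}
  \approx \begin{pmatrix} 0.7071 & -2.8284 \\ 0.7071 & 2.8284 \end{pmatrix}
$$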

Upvotes: 1

Views: 261

Answers (1)

Brutus

Reputation: 103

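Your function f does not depend on its argument var: x and y are computed once, outside the function, from the global var, so inside f they are just constants. The derivative of a constant is zero, which is why jax.jacobian returns an array of zeros.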
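The same effect can be reproduced with jax.grad in a one-variable setting. This is a minimal sketch; the names c, ignores_input and uses_input are hypothetical, chosen only for illustration. A function that closes over a precomputed value gets a zero gradient, while one that computes from its argument gets the expected derivative.

```python
import jax
import jax.numpy as jnp

# Computed once, outside any function: to JAX this is just a constant.
c = jnp.sin(1.0)

def ignores_input(t):
    # t is never used, so the output cannot depend on it.
    return c

def uses_input(t):
    # The computation is done on the argument, so JAX can trace it.
    return jnp.sin(t)

print(jax.grad(ignores_input)(1.0))  # 0.0
print(jax.grad(uses_input)(1.0))     # approximately cos(1.0), about 0.5403
```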

This would give the desired output instead:

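import jax
import jax.numpy as jnp
import numpy as np

theta = np.pi/4
r = 4.0

var = np.array([r, theta])

def f(var):
    # Compute x and y from the argument, so JAX can trace their dependence on var
    x = var[0]*jnp.cos(var[1])
    y = var[0]*jnp.sin(var[1])
    # Return a jax.numpy array: np.array cannot hold traced values
    return jnp.array([x, y])

jac = jax.jacobian(f)(var)
jac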
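To check that the result is now correct, one option is to compare it with the analytic Jacobian of x = r cos θ, y = r sin θ, namely [[cos θ, -r sin θ], [sin θ, r cos θ]], whose determinant is r. A minimal sketch; polar_to_cartesian and analytic_jacobian are hypothetical helper names, not part of JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np

def polar_to_cartesian(var):
    # var = [r, theta]; returns [x, y]
    r, theta = var[0], var[1]
    return jnp.array([r * jnp.cos(theta), r * jnp.sin(theta)])

def analytic_jacobian(var):
    # Rows: x, y. Columns: derivative w.r.t. r, derivative w.r.t. theta.
    r, theta = var[0], var[1]
    return jnp.array([[jnp.cos(theta), -r * jnp.sin(theta)],
                      [jnp.sin(theta),  r * jnp.cos(theta)]])

var = jnp.array([4.0, np.pi / 4])
jac = jax.jacobian(polar_to_cartesian)(var)

print(jac)
print(jnp.allclose(jac, analytic_jacobian(var)))  # True
print(jnp.linalg.det(jac))                        # approximately 4.0, i.e. r
```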

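Note also that f must return a jax.numpy array (jnp.array) rather than a NumPy array (np.array): inside jax.jacobian, x and y are traced values, and converting them with np.array raises an error (a TracerArrayConversionError).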
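A short sketch of that failure mode, assuming a recent JAX version: f_numpy (a hypothetical name) does the computation on its argument but returns np.array, and jax.jacobian should raise rather than return a result. The exact exception class may vary between JAX versions, so the example only records its name.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f_numpy(var):
    x = var[0] * jnp.cos(var[1])
    y = var[0] * jnp.sin(var[1])
    # Inside jax.jacobian, x and y are tracers; np.array asks them for
    # concrete NumPy values, which they refuse to provide.
    return np.array([x, y])

var = np.array([4.0, np.pi / 4])

try:
    jax.jacobian(f_numpy)(var)
    error_name = None
except Exception as err:  # expected: jax.errors.TracerArrayConversionError
    error_name = type(err).__name__

print(error_name)
```

Swapping np.array for jnp.array in the return statement makes the same function differentiable.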

Upvotes: 2
