Reputation: 720
I'm working with gradients in JAX and having some trouble. Here is my code:
```python
import jax

def model(x):
    return (x+1)**2 + (x-1)**2

def loss(x, y):
    return y - model(x)

x = 2
grad = jax.grad(loss, argnums=0)
gradient = grad(x, 0)
```
On the last line, I get the following error:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,). The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified
How can I solve this?
Upvotes: 2
Views: 3270
Reputation: 2124
The jax.grad function requires real- or complex-valued (floating-point) inputs, but you are passing the integer 2 for x. Pass floats instead and it works:
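```python
import jax

def model(x):
    return (x + 1) ** 2 + (x - 1) ** 2

def loss(x, y):
    return y - model(x)

x = 2.0  # a float, not the integer 2: jax.grad only differentiates floating-point inputs
grad = jax.grad(loss)  # differentiates with respect to x (argnums=0 is the default)
gradient = grad(x, 0.0)
```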
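The quoted error message mentions an output of shape (1,), which suggests that in the real code x (or y) may be a one-element array rather than a plain number. Below is a minimal sketch, assuming a recent JAX release, that covers both cases: an integer input, which jax.grad rejects with a TypeError; a float scalar, which works; and a shape-(1,) input, where the loss output must first be reduced to a scalar. Reducing with jnp.sum is my own choice for illustration, not something from the original post; indexing with [0] would work just as well for a single element. As a sanity check, d/dx [y - ((x+1)^2 + (x-1)^2)] = -4x, so the gradient at x = 2.0 should be -8.0.

```python
import jax
import jax.numpy as jnp

def model(x):
    return (x + 1) ** 2 + (x - 1) ** 2

def loss(x, y):
    return y - model(x)

# Integer input: jax.grad raises a TypeError because x is not floating point.
try:
    jax.grad(loss)(2, 0)
    int_input_failed = False
except TypeError:
    int_input_failed = True

# Float scalar input: works. Analytically the gradient is -4x, i.e. -8.0 at x = 2.0.
g = jax.grad(loss)(2.0, 0.0)

# Shape-(1,) input: loss(x, y) also has shape (1,), which jax.grad rejects.
# Summing the output gives a scalar, so grad applies; the result has x's shape.
x_vec = jnp.array([2.0])
g_vec = jax.grad(lambda x, y: jnp.sum(loss(x, y)))(x_vec, 0.0)

print(int_input_failed, g, g_vec)
```

If you need one gradient per element of a larger array instead of the gradient of a sum, jax.vmap(jax.grad(loss), in_axes=(0, None)) is the usual alternative, since it applies the scalar version element by element.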
Upvotes: 2