Lorenzo Cutrupi

Reputation: 720

Gradient only defined for scalar-output functions. Output had shape: (1,)

I'm working with gradients and having some trouble. Here is my code:

import jax

def model(x):
    return (x+1)**2 + (x-1)**2

def loss(x, y):
    return y - model(x)

x = 2
grad = jax.grad(loss, argnums=0)
gradient = grad(x, 0)

On the last line, I get the following error:

TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,). The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified

How can I solve this?

Upvotes: 2

Views: 3270

Answers (1)

amarchin

Reputation: 2124

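jax.grad requires real- or complex-valued (floating-point) inputs, but you are passing an integer (x = 2). Passing a float instead fixes it (argnums=0 is the default, so it can be omitted). This version works: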

import jax

def model(x):
  return (x+1)**2 + (x-1)**2

def loss(x, y):
  return y - model(x)

x=2.0
grad = jax.grad(loss)
gradient = grad(x, 0.0)
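The snippet below is a hedged sketch (assuming a recent JAX version) that shows the behaviors side by side. An integer x is rejected with a TypeError, while a float x gives the expected gradient: since loss = y - (x+1)^2 - (x-1)^2, d(loss)/dx = -4x, which is -8.0 at x = 2.0. The third part is an assumption about where the "(1,)" in the question's message could come from: if x is a length-1 array such as jnp.array([2.0]), loss returns shape (1,), and jax.grad rejects it unless the output is reduced to a scalar (here with jnp.sum).

```python
import jax
import jax.numpy as jnp

def model(x):
    return (x + 1) ** 2 + (x - 1) ** 2

def loss(x, y):
    return y - model(x)

# argnums=0 is the default, so this differentiates loss with respect to x.
grad_loss = jax.grad(loss)

# 1) Integer input: jax.grad refuses non-floating-point inputs with a TypeError.
try:
    grad_loss(2, 0)
except TypeError as err:
    print("integer input rejected:", type(err).__name__)

# 2) Float input: works. Analytically d(loss)/dx = -2(x+1) - 2(x-1) = -4x,
#    so the gradient at x = 2.0 is -8.0.
print(grad_loss(2.0, 0.0))

# 3) Assumption: x is a length-1 array. loss then returns shape (1,), which
#    triggers "Gradient only defined for scalar-output functions".
x_vec = jnp.array([2.0])
try:
    grad_loss(x_vec, 0.0)
except TypeError as err:
    print("shape-(1,) output rejected:", type(err).__name__)

# Reducing the output to a scalar makes it differentiable with jax.grad;
# the gradient has the same shape as x_vec, i.e. (1,).
grad_scalar = jax.grad(lambda x, y: jnp.sum(loss(x, y)))
print(grad_scalar(x_vec, 0.0))
```

Summing is just one way to get a scalar here; indexing the single element (loss(x, y)[0]) would work equally well for a length-1 array.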

Upvotes: 2
