Marcos Vinicius

Reputation: 1

How to use JAX/Autograd for Jacobian computation in scipy.optimize?

I want to use scipy.optimize to globally minimize a multivariate function in Python. This function takes a vector as input and returns a scalar. Some of the local minimization methods that the global optimizer uses under the hood are more efficient when one provides the Jacobian of the function being minimized. I'm trying to use JAX/Autograd to compute this Jacobian, but I always get a TracerArrayConversionError when using it inside scipy.optimize. How can I use JAX/Autograd to compute the Jacobian and pass it to the minimization method?

The link suggested by the error message did not make clear what I could do to fix the problem I'm dealing with. Some observations:

  1. It is worth mentioning that I'm not a frequent user of JAX/Autograd. I'm only using them for this task, so I'm not very familiar with their capabilities.

  2. I do need to compute this Jacobian numerically since, in the future, I intend to use more complicated methods to compute the derivative via the deriv function.

I would appreciate any alternative to JAX/Autograd. I tried wrapping the jacfwd(energy) function to make it usable by scipy.optimize, but that did not fix the problem.
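If JAX is not a hard requirement, a finite-difference gradient is one alternative. The sketch below uses scipy.optimize.approx_fprime on a toy quadratic f, a hypothetical stand-in for energy chosen only because its minimum (2/3 in every component) is easy to check. Each gradient costs about len(v) + 1 evaluations of f, so for n = 100 it is considerably more expensive than an automatic-differentiation gradient.

```python
import numpy as np
from scipy import optimize

def f(v):
    # Hypothetical toy objective (stand-in for energy); minimum at v = 2/3.
    return np.sum((v - 1.0) ** 2) + 0.5 * np.sum(v ** 2)

def jac_fd(v):
    # Forward-difference gradient with step 1e-8; no autodiff involved.
    return optimize.approx_fprime(v, f, 1e-8)

x0 = np.zeros(5)
res = optimize.minimize(f, x0, method="TNC", jac=jac_fd, bounds=[(-2.0, 2.0)] * 5)
print(res.x)
```

Omitting jac entirely should also make SciPy fall back to its own finite-difference approximation for gradient-based methods such as TNC; the explicit wrapper just makes the cost visible.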

Here is my code:

import numpy as np
import scipy.optimize as optimize
from fivepointstencil import derivative as deriv

n = 100
x = np.linspace(-20,20, n, endpoint=True)
dx = np.diff(x)[0]

# function under minimization
def energy(phi):
    return sum( ( (0.5*(deriv(x,phi)**2)) + 0.5*((1 - phi**2)**2) )*dx)

bounds = [(-1,1) for i in range(n)]
bounds[0] = (-1.001,-0.999) #the boundary values for phi are fixed
bounds[-1] = (0.99,1.01)

# HERE I DEFINE THE JACOBIAN FUNCTION
from jax import jacfwd
def jac(phi):
    jac = jacfwd(energy)
    final_jac = jac(phi)
    return np.array(final_jac)

results = optimize.dual_annealing(energy, bounds=bounds, initial_temp=5e4,
                                minimizer_kwargs={'method':'TNC',
                                'bounds':bounds, 'jac':jac})

Here, the deriv function simply differentiates an array using the five-point stencil technique.
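The fivepointstencil module is not shown, so the following is only a sketch of a JAX-traceable five-point-stencil derivative, assuming a uniform grid and at least five points. The name deriv_jax is hypothetical, and the boundary treatment (second-order one-sided and central formulas at the two points at each end) is an assumption, not necessarily what the original deriv does.

```python
import jax
import jax.numpy as jnp

def deriv_jax(x, phi):
    """Hypothetical five-point-stencil derivative of phi on a uniform grid x."""
    h = x[1] - x[0]  # assumes uniform spacing
    # Interior: f'(x_i) ~ (f[i-2] - 8 f[i-1] + 8 f[i+1] - f[i+2]) / (12 h)
    interior = (phi[:-4] - 8.0 * phi[1:-3] + 8.0 * phi[3:-1] - phi[4:]) / (12.0 * h)
    # Left end: second-order one-sided formula, then a central difference.
    left = jnp.stack([
        (-3.0 * phi[0] + 4.0 * phi[1] - phi[2]) / (2.0 * h),
        (phi[2] - phi[0]) / (2.0 * h),
    ])
    # Right end: a central difference, then a second-order one-sided formula.
    right = jnp.stack([
        (phi[-1] - phi[-3]) / (2.0 * h),
        (3.0 * phi[-1] - 4.0 * phi[-2] + phi[-3]) / (2.0 * h),
    ])
    return jnp.concatenate([left, interior, right])

x = jnp.linspace(-1.0, 1.0, 11)
print(deriv_jax(x, x ** 2))  # approximately 2 * x

# Only jax.numpy operations are used, so JAX can trace and differentiate it.
g = jax.grad(lambda p: jnp.sum(deriv_jax(x, p) ** 2))(x ** 2)
print(g.shape)  # (11,)
```

The five-point formula is exact for polynomials up to degree four at interior points, which makes x**2 and x**3 convenient sanity checks.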

The error I'm getting is:

jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[100] See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
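The error can be reproduced in isolation. In this sketch, two hypothetical functions compute the same value: one converts its input with np.asarray (as many NumPy functions do internally), the other uses only jax.numpy. Under jacfwd the input is a traced array, so the NumPy conversion raises TracerArrayConversionError while the pure jax.numpy version differentiates fine.

```python
import numpy as np
import jax
import jax.numpy as jnp

def uses_numpy(phi):
    # np.asarray tries to turn the tracer into a concrete numpy.ndarray.
    return jnp.sum(np.asarray(phi) ** 2)

def uses_jax_numpy(phi):
    # Same computation, expressed only with jax.numpy.
    return jnp.sum(phi ** 2)

phi = jnp.linspace(-1.0, 1.0, 5)
try:
    jax.jacfwd(uses_numpy)(phi)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

print(raised)  # True
print(jax.jacfwd(uses_jax_numpy)(phi))  # gradient 2 * phi
```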

Upvotes: 0

Views: 45

Answers (1)

bigmacsetnotenough

Reputation: 112

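Your energy function relies on NumPy code, in particular the custom deriv function, which most likely uses NumPy internally. When jacfwd traces energy, it passes in traced arrays instead of concrete ones, and a NumPy function that receives a tracer tries to convert it into a numpy.ndarray. That conversion is exactly what TracerArrayConversionError reports, which is why you get the error.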

Try replacing all NumPy operations in energy and deriv with their jax.numpy equivalents (for example, jnp.sum instead of sum).
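A minimal sketch of that suggestion, under a few stated assumptions: jnp.gradient stands in for the question's fivepointstencil.deriv (it uses second-order central differences, not a five-point stencil), float64 is enabled so JAX matches the arrays SciPy passes in, and jax.grad replaces jacfwd because energy returns a scalar, so its Jacobian is just the gradient. The wrapper names fun and jac are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy import optimize

# SciPy works in float64; without this flag JAX computes in float32.
jax.config.update("jax_enable_x64", True)

n = 100
x = np.linspace(-20, 20, n, endpoint=True)
dx = x[1] - x[0]

def deriv(x, phi):
    # Stand-in for fivepointstencil.deriv: jnp.gradient uses second-order
    # central differences. Any deriv written only with jax.numpy fits here.
    return jnp.gradient(phi, x[1] - x[0])

def energy(phi):
    # Same energy as in the question, with jnp.sum instead of the builtin sum.
    return jnp.sum((0.5 * deriv(x, phi) ** 2 + 0.5 * (1.0 - phi ** 2) ** 2) * dx)

# energy maps R^n to a scalar, so its Jacobian is the gradient.
# Build and compile the gradient function once, not on every call.
energy_grad = jax.jit(jax.grad(energy))

def fun(phi):
    # SciPy expects a plain float for the objective value.
    return float(energy(phi))

def jac(phi):
    # SciPy expects a float64 NumPy array for the gradient.
    return np.asarray(energy_grad(phi), dtype=np.float64)

bounds = [(-1, 1) for _ in range(n)]
bounds[0] = (-1.001, -0.999)  # the boundary values for phi are fixed
bounds[-1] = (0.99, 1.01)

# Local TNC run from a linear ramp, just to exercise fun/jac quickly.
phi0 = x / 20.0
res = optimize.minimize(fun, phi0, method="TNC", jac=jac, bounds=bounds)
print(res.fun, fun(phi0))
```

For the global search, the same pair should plug into the original call, e.g. optimize.dual_annealing(fun, bounds=bounds, initial_temp=5e4, minimizer_kwargs={'method': 'TNC', 'bounds': bounds, 'jac': jac}); the local TNC run only keeps the example fast.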

Upvotes: 1
