

optimize.curve_fit does not cover parameter space

I am trying to fit temperature and precipitation data to a periodic function using curve_fit. For some reason, curve_fit does not appear to test the entire parameter space as defined by the bounds parameter. I threw together a little model to demonstrate this.

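# First generate some data
import numpy as np

def test_func(x, b, c, d, e=0.0):
    # Linear trend (slope e) plus a sine with amplitude b, period c and phase d
    return e*x + b * np.sin(2*np.pi * x/c + d)

def func_err(x, y):
    # Mean squared error between the data x and the model values y
    x, y = np.asarray(x), np.asarray(y)
    return np.mean((x - y)**2)

# Seed the random number generator for reproducibility (seed value chosen arbitrarily)
np.random.seed(0)

x_data = np.linspace(0, 10, num=50)*2*np.pi
y_data = test_func(x_data,5.0,20,-0.20, 0.1) + 1.0*np.random.normal(size=50)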

# And plot it
import matplotlib.pyplot as plt

# Now fit a simple sine function to the data
from scipy import optimize

params, params_covariance = optimize.curve_fit(test_func, x_data, y_data,
                                               p0=[1.0, 18, 0.0, 0.0],
                                               bounds=([0.1,5,-5.0, -5.0],[100,100,5.0, 5.0]))

print([params,func_err(y_data,test_func(x_data,params[0],params[1], params[2], params[3]))])

# And plot the resulting curve on the data

plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data, label='Data')
plt.plot(x_data, test_func(x_data, params[0], params[1],params[2], params[3]),
         label='Fitted function')
plt.legend()
plt.show()


With p0=[1.0, 18, 0.0, 0.0] as given, the routine finds a good solution (Fig. 1: good fit), but with an initial guess like p0=[1.0, 10, 0.0, 0.0] it fails dramatically (Fig. 2: bad fit). Why doesn't the routine search the whole range given by the bounds to find its solution?


Answers (1)



I think this is due to the nature of the periodic function. Your parameter c determines the period of the function. When your initial guess for the period is far from the true one, the fit gets stuck in a local minimum.

You can think of it this way: with p0=[1.0, 10, 0.0, 0.0], the fitting algorithm finds the local best fit shown in your second figure, [ 0.65476428, 11.14188385, -1.09652992, 0.08971854] for [b, c, d, e]. It then tries small moves of the parameters around that point, but the gradient there indicates that this is the best fit, as if it were at the bottom of a "valley" in parameter space, so it stops iterating.

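curve_fit does not explore your entire parameter space: it starts from your initial guess (p0 here) and uses a local least-squares optimizer (when bounds are given, the default is the Trust Region Reflective method, 'trf') to move downhill to the nearest local optimum. The bounds only restrict where it may go; they are not searched.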

If you want to explore the entire range of parameter c, you can implement a simple grid search: split the bounds for c into small intervals, run curve_fit once per interval with c confined to that interval, and keep the result with the smallest fitting error.

Here is some example code:

def MSE(params,x_data,y_data):
    "Mean squared error; equivalent to your func_err"
    return ((test_func(x_data,*params)-y_data)**2).mean()

besterror = np.inf
bestParam = None

for c_ in np.arange(5,100,1):
    # grid search for parameter c between 5 and 100, step size is 1.
    params, params_covariance = optimize.curve_fit(test_func, x_data, y_data,
                                                   p0=[1.0, c_+0.5, 0.0, 0.0],
                                                   bounds=([0.1,c_,-5.0, -5.0],[100,c_+1,5.0, 5.0]))
    error = MSE(params,x_data,y_data)
    if error<besterror:
        besterror = error 
        bestParam = params

params = bestParam

print([params,func_err(y_data,test_func(x_data,params[0],params[1], params[2], params[3]))])

# And plot the resulting curve on the data


plt.figure(figsize=(6, 4))
plt.scatter(x_data, y_data, label='Data')
plt.plot(x_data, test_func(x_data, params[0], params[1],params[2], params[3]),
         label='Fitted function')
plt.legend()
plt.show()


A grid search over the other parameters is unnecessary in this case, because once c is in the right interval, curve_fit is good enough at finding optimal values for the other parameters.

This is a brute-force method; there may be libraries that can do this more efficiently.

