SpecificField

Reputation: 43

Using parameters as bounds for scipy.optimize.curve_fit

I was wondering whether it is possible to set bounds for the parameters in curve_fit() that depend on another parameter. For example, say I wanted to constrain the slope of a line to be greater than the intercept.

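import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def linear(x, m, b):
    return m*x + b

def plot_linear(x, y):
    # Intended: lower bound on m equal to b, i.e. m >= b.
    # This fails: b is not defined when the bounds are built.
    B = ([b, -np.inf], [np.inf, np.inf])
    p, v = curve_fit(linear, x, y, bounds=B)

    xs = np.linspace(min(x), max(x), 1000)

    plt.plot(x, y, '.')
    plt.plot(xs, linear(xs, *p), '-')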

I know this doesn't work because b is not defined at the point where the bounds are built, but is there a way to make this work?

Upvotes: 2

Views: 586

Answers (1)

Sandipan Dey

Reputation: 23099

We can always re-parameterize for the specific curve-fitting problem. For example, to fit y = mx + b subject to m >= b, rewrite the slope as m = b + k*k with a new parameter k. Since k*k >= 0, the constraint holds automatically, and we can optimize over the unconstrained parameters k and b as follows:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def linear(x, m, b):
    return m*x + b

def linear2(x, k, b):   # constrained fit, m = b + k**2 >= b
    return (b+k**2)*x + b

def plot_linear(x, y):
    p, v = curve_fit(linear, x, y)
    print(p)
    # [3.1675609  6.01025041]
    p2, v2 = curve_fit(linear2, x, y)
    print(p2)
    # [2.13980283e-05 4.99368661e+00]   (k ~ 0, so m = b + k**2 ~ b)
    xs = np.linspace(min(x), max(x), 1000)
    plt.plot(x, y, '.')
    plt.plot(xs, linear(xs, *p), 'r-', label='unconstrained fit')
    plt.plot(xs, linear2(xs, *p2), 'b-', label='constrained (m>=b) fit')
    plt.legend()
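A variation on the same idea, offered here as an assumption rather than as part of the original answer: substitute m = b + d and let curve_fit's own bounds argument keep d >= 0, so the bound is a constant even though the constraint involves two parameters. The helper name `linear_shifted`, the random seed, and the synthetic data (mirroring the data used below) are illustrative choices.

```python
import numpy as np
from scipy.optimize import curve_fit

def linear_shifted(x, d, b):
    # hypothetical helper: slope m = b + d, and d >= 0 is enforced via bounds
    return (b + d) * x + b

# synthetic data similar to the example below, with a fixed seed (assumption)
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 100)
y = 3 * x + 5 + 2 * rng.random(len(x))

# lower bound 0 on d, no bound on b, so m = b + d >= b
bounds = ([0, -np.inf], [np.inf, np.inf])
(d, b), cov = curve_fit(linear_shifted, x, y, bounds=bounds)
m = b + d  # recover the slope from the fitted parameters
print(f"m = {m:.4f}, b = {b:.4f}, d = {d:.2e}")
```

With bounds given, curve_fit switches from the default 'lm' method to 'trf'. One possible advantage over the k**2 substitution is that the model's derivative with respect to k, 2*k*x, vanishes at k = 0, whereas the derivative with respect to d does not.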

Now fit both the unconstrained and the constrained model to the following data (note that the unconstrained optimum here has a slope smaller than the intercept, so the constraint actually matters):

x = np.linspace(0,1,100)
y = 3*x + 5 + 2*np.random.rand(len(x))
plot_linear(x, y)
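When a constraint cannot easily be absorbed by a substitution, one alternative (not from the original answer, and a different tool than curve_fit) is to minimize the sum of squared residuals directly with scipy.optimize.minimize, which accepts general inequality constraints for methods such as SLSQP. The sketch below assumes the same kind of synthetic data as above, with a fixed seed, and encodes m >= b as an 'ineq' constraint.

```python
import numpy as np
from scipy.optimize import minimize

# synthetic data as in the example above, with a fixed seed (assumption)
rng = np.random.default_rng(0)
x = np.linspace(0, 1, 100)
y = 3 * x + 5 + 2 * rng.random(len(x))

def sse(params):
    # sum of squared residuals for the line y = m*x + b
    m, b = params
    return np.sum((y - (m * x + b)) ** 2)

# an 'ineq' constraint requires fun(params) >= 0, here m - b >= 0
constraints = [{"type": "ineq", "fun": lambda p: p[0] - p[1]}]
res = minimize(sse, x0=[1.0, 1.0], method="SLSQP", constraints=constraints)
m, b = res.x
print(res.success, m, b)
```

Note that minimize does not return a parameter covariance matrix the way curve_fit does, so parameter uncertainties would need to be estimated separately.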


Upvotes: 2
