Binh Thien

Reputation: 395

Fitting issue with curve_fit

Can anyone help me with a fitting issue in curve_fit? I would like to fit my data to a second-order equation, but the result I obtain looks like a linear equation. Here is my code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a, b, c):
    f = a*np.power(x, 2) + b*x + c
    return f

xdata_prime=[3.0328562996216282, 3.101784841139168, 3.1707134502066894, 3.2396419917242292, 3.308570533241769, 3.3774990747593088, 3.3774990747593088, 3.4337789932367149, 3.4900589392912855, 3.5463388577686916, 3.6026187762460977, 3.6588987223006684]
ydata_prime=[6.344300000000002, 6.723900000000002, 7.080399999999999, 7.399800000000001, 7.649099999999999, 7.753100000000002, 7.753100000000002, 7.658600000000002, 7.442100000000002, 7.180100000000001, 6.902700000000001, 6.6211]

plt.plot(xdata_prime, ydata_prime, 'b-', label='data')
popt, pcov = curve_fit(func, xdata_prime, ydata_prime)
print(popt)  # fitted coefficients a, b, c
plt.plot(xdata_prime, func(xdata_prime, *popt), 'r-',label='fit')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
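
As an independent cross-check (an addition, not part of the original question): a second-order polynomial is linear in its coefficients, so NumPy's np.polyfit with deg=2 solves the same least-squares problem directly and accepts plain Python lists. A negative leading coefficient would confirm the data are concave, i.e. genuinely quadratic rather than linear. The names a, b, c and x_vertex are illustrative.

```python
import numpy as np

# Same data as in the question, kept as plain Python lists
xdata_prime = [3.0328562996216282, 3.101784841139168, 3.1707134502066894, 3.2396419917242292, 3.308570533241769, 3.3774990747593088, 3.3774990747593088, 3.4337789932367149, 3.4900589392912855, 3.5463388577686916, 3.6026187762460977, 3.6588987223006684]
ydata_prime = [6.344300000000002, 6.723900000000002, 7.080399999999999, 7.399800000000001, 7.649099999999999, 7.753100000000002, 7.753100000000002, 7.658600000000002, 7.442100000000002, 7.180100000000001, 6.902700000000001, 6.6211]

# Least-squares fit of y = a*x**2 + b*x + c; np.polyfit returns the highest power first
a, b, c = np.polyfit(xdata_prime, ydata_prime, deg=2)

# x-position of the parabola's vertex (a maximum when a < 0)
x_vertex = -b / (2 * a)
print(f"a={a:.4f}, b={b:.4f}, c={c:.4f}, vertex at x={x_vertex:.4f}")
```

If these coefficients differ sharply from what curve_fit reports, the problem lies in how the data reach the model function rather than in the data themselves.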

Upvotes: 0

Views: 199

Answers (1)

Fried Noodles

Reputation: 297

Your data need to be NumPy arrays rather than Python lists, because your function performs vectorized operations (namely a*np.power(x, 2)). With this change, your code works:

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func(x, a, b, c):
    f = a*np.power(x, 2) + b*x + c
    return f

# NumPy arrays (not lists) so that the arithmetic in func is element-wise
xdata_prime=np.array([3.0328562996216282, 3.101784841139168, 3.1707134502066894, 3.2396419917242292, 3.308570533241769, 3.3774990747593088, 3.3774990747593088, 3.4337789932367149, 3.4900589392912855, 3.5463388577686916, 3.6026187762460977, 3.6588987223006684])
ydata_prime=np.array([6.344300000000002, 6.723900000000002, 7.080399999999999, 7.399800000000001, 7.649099999999999, 7.753100000000002, 7.753100000000002, 7.658600000000002, 7.442100000000002, 7.180100000000001, 6.902700000000001, 6.6211])

plt.plot(xdata_prime, ydata_prime, 'b-', label='data')
popt, pcov = curve_fit(func, xdata_prime, ydata_prime)
plt.plot(xdata_prime, func(xdata_prime, *popt), 'r-',label='fit')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
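
A variant of the same fix, assuming you would rather not convert every input by hand: call np.asarray inside the model function so it behaves the same whether it receives a list or an array. This is a swapped-in variant, not the answer's exact code; the name func_robust and the dense grid x_fine are hypothetical additions. Evaluating on a finer grid draws a smooth parabola instead of straight segments between the data points.

```python
import numpy as np
from scipy.optimize import curve_fit

def func_robust(x, a, b, c):
    # Convert lists (or scalars) to a float array so the arithmetic is element-wise
    x = np.asarray(x, dtype=float)
    return a * x**2 + b * x + c

# Plain lists, exactly as in the question
xdata_prime = [3.0328562996216282, 3.101784841139168, 3.1707134502066894, 3.2396419917242292, 3.308570533241769, 3.3774990747593088, 3.3774990747593088, 3.4337789932367149, 3.4900589392912855, 3.5463388577686916, 3.6026187762460977, 3.6588987223006684]
ydata_prime = [6.344300000000002, 6.723900000000002, 7.080399999999999, 7.399800000000001, 7.649099999999999, 7.753100000000002, 7.753100000000002, 7.658600000000002, 7.442100000000002, 7.180100000000001, 6.902700000000001, 6.6211]

popt, pcov = curve_fit(func_robust, xdata_prime, ydata_prime)

# Evaluate the fitted parabola on a dense grid (hypothetical: 200 points) for a smooth curve
x_fine = np.linspace(min(xdata_prime), max(xdata_prime), 200)
y_fine = func_robust(x_fine, *popt)
print("a, b, c =", popt)
```

To plot it, pass x_fine and y_fine to plt.plot in place of the data x values, e.g. plt.plot(x_fine, y_fine, 'r-', label='fit').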

Upvotes: 1
