Allan Sun

Reputation: 1

Fitting a multivariate logistic function with curve_fit in Python

I am fitting a logistic function to an X of shape (2, 5), i.e. 5 samples with 2 features each:

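import numpy as np
from scipy.optimize import curve_fit

def logifunc(x, A, x0, k):
    # A: upper asymptote, x0: midpoint, k: growth rate
    return A / (1 + np.exp(-k * (x - x0)))

popt, pcov = curve_fit(logifunc, X, y)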

My X is:

array([[ 0.000e+00,  1.000e+00,  2.000e+00,  3.000e+00,  4.000e+00],
       [-1.000e+00,  5.480e+02,  6.430e+02,  9.200e+02,  1.406e+03]])

y is:

array([ 548.,  643.,  920., 1406., 2075.])

I got this error message:

TypeError                                 Traceback (most recent call last)
<ipython-input-68-acf83ee46c52> in <module>()
     14 X = np.vstack((t.ravel(), np.array(P).ravel()))
     15 
---> 16 popt, pcov = curve_fit(logifunc, X, y)

1 frames
/usr/local/lib/python3.6/dist-packages/scipy/optimize/minpack.py in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    390 
    391     if n > m:
--> 392         raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))
    393 
    394     if epsfcn is None:

TypeError: Improper input: N=3 must not exceed M=2

I am not sure how to resolve this error.
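A minimal NumPy-only sketch, using the X and y from the question, that reproduces the two numbers in the message. It assumes that curve_fit forms the residuals as `logifunc(X, *params) - y`, starting from its default guess of 1 for every parameter, and that `leastsq` (as in the SciPy version shown in the traceback) takes M from the first axis of that residual array.

```python
import numpy as np

def logifunc(x, A, x0, k):
    # Element-wise logistic curve, as defined in the question
    return A / (1 + np.exp(-k * (x - x0)))

# Data from the question: row 0 is t, row 1 is the second feature
X = np.array([[0.0, 1.0, 2.0, 3.0, 4.0],
              [-1.0, 548.0, 643.0, 920.0, 1406.0]])
y = np.array([548.0, 643.0, 920.0, 1406.0, 2075.0])

p0 = [1.0, 1.0, 1.0]  # curve_fit's default start: one 1.0 per parameter

# Residuals as curve_fit forms them: the (2, 5) model output minus y,
# with y of shape (5,) broadcast across both rows
residuals = logifunc(X, *p0) - y

n = len(p0)             # N: number of parameters being fitted
m = residuals.shape[0]  # M: length of the first axis of the residuals

print(X.shape, residuals.shape)  # (2, 5) (2, 5)
print(f"N={n}, M={m}")           # N=3, M=2
```

Because logifunc is applied element-wise, the model returns one row per row of X instead of one value per sample, so the optimizer sees only 2 "data points" for 3 parameters.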

Upvotes: 0

Views: 113

Answers (1)

Muhammad Junaid Haris

Reputation: 452

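In this error, N is the number of parameters being fitted (here 3: A, x0 and k) and M is the number of rows in the residual array that the least-squares routine receives. Because logifunc is applied element-wise to the whole (2, 5) array X, it returns a (2, 5) array, one row per row of X, so M is 2. With fewer rows than parameters, the fit cannot proceed. curve_fit expects the model to return one prediction per entry of y (5 values here), so the function needs to combine the rows of X, e.g. through x[0] and x[1], into a single length-5 output.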
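A minimal sketch of the layout curve_fit expects for several predictors, following the SciPy documentation: xdata may be a (k, M) array for k predictors, and the model must return M values, one per entry of ydata. The model `logistic2`, its parameters `k1`, `k2` and `c`, and the data are hypothetical and synthetic; they only illustrate the shapes, not the asker's actual model.

```python
import numpy as np
from scipy.optimize import curve_fit

def logistic2(X, A, k1, k2, c):
    """Hypothetical two-predictor logistic model.

    X has shape (2, M): row 0 and row 1 are the two predictors.
    Returns one prediction per column, i.e. an array of length M.
    """
    x1, x2 = X  # unpack the two rows into two length-M arrays
    return A / (1 + np.exp(-(k1 * x1 + k2 * x2 + c)))

# Synthetic, noise-free data generated from known parameters (illustration only)
g1, g2 = np.meshgrid(np.linspace(-2, 2, 6), np.linspace(-2, 2, 5))
X = np.vstack((g1.ravel(), g2.ravel()))  # shape (2, 30), same layout as in the question
true_params = (10.0, 1.5, -0.8, 0.5)
y = logistic2(X, *true_params)           # shape (30,)

# p0 supplies one starting value per parameter (A, k1, k2, c)
popt, pcov = curve_fit(logistic2, X, y, p0=[9.0, 1.0, -1.0, 0.0])

print(logistic2(X, *popt).shape)  # (30,)
print(np.round(popt, 3))
```

For the question's data the same pattern would apply: keep X = np.vstack((t.ravel(), np.array(P).ravel())) and write the model so it reads x[0] and x[1] (or only x[0] for a purely time-based logistic) and returns a length-5 array. With only 5 samples and y in the hundreds to thousands, passing a p0 on the scale of the data is likely to matter, since the default start sets every parameter to 1.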

Upvotes: 1
