import numpy as np
from scipy.optimize import curve_fit

def func(x, a, b):
    return a * np.exp(-b * x)

xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
popt, pcov = curve_fit(func, xdata, ydata)
It is quite easy to fit an arbitrary Gaussian in Python with something like the method above. However, I would like to write a function that allows the user to select an arbitrary number of Gaussians and still attempts to find a best fit.
I'm trying to figure out how to modify the function func so that I can pass it an additional parameter, for instance n=2, and have it fit 2 Gaussians, akin to:
import numpy as np
from scipy.optimize import curve_fit

def func2(x, a, b, c, d):
    return a * np.exp(-b * x) + c * np.exp(-d * x)

xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
popt, pcov = curve_fit(func2, xdata, ydata)
I want to do this without hard-coding each additional case, so that I could instead pass a single function like func(..., n=2) and get the same result as above. I'm having trouble finding an elegant solution; my guess is that the best solution will involve a lambda function.
You can define the function to take a variable number of arguments using def func(x, *args). The args tuple then contains something like (a1, b1, a2, b2, a3, ...). It would be possible to loop over those and sum the Gaussians, but I'm showing a vectorized solution instead.
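For comparison, here is a minimal sketch of the loop-based alternative, assuming the same (a1, b1, a2, b2, ...) ordering. The name func_loop is hypothetical; it should return the same values as the vectorized func in the code below.

```python
import numpy as np

def func_loop(x, *args):
    """Sum of len(args) // 2 terms a_i * exp(-b_i * x), using a plain loop.

    args is interpreted as (a1, b1, a2, b2, ...); an odd count is rejected.
    """
    if len(args) % 2 != 0:
        raise ValueError("expected an even number of parameters (a, b pairs)")
    x = np.asarray(x, dtype=float)
    total = np.zeros_like(x)
    # args[0::2] are the amplitudes a_i, args[1::2] are the rates b_i.
    for a, b in zip(args[0::2], args[1::2]):
        total += a * np.exp(-b * x)
    return total
```

For example, func_loop(x, 1.0, 0.5, 2.0, 1.5) evaluates 1.0 * exp(-0.5 * x) + 2.0 * exp(-1.5 * x). The loop is easier to read; the broadcasting version avoids the Python-level loop, which mainly matters for many terms.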
Since curve_fit can no longer determine the number of parameters by inspecting this function's signature (it only sees x), you have to provide an initial guess p0; without it, curve_fit raises a ValueError. The length of p0 then determines how many Gaussians are fitted. Each Gaussian requires two parameters, so [1, 1] * n produces a parameter vector of the correct length.
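A small sketch of why p0 is required, using a single-term variadic model on synthetic, noise-free data (the name model and the data values are hypothetical). A lambda such as lambda x, *p: ... has the same signature problem, so the lambda idea from the question would also need p0.

```python
import numpy as np
from scipy.optimize import curve_fit

def model(x, *p):
    # Variadic model: curve_fit can only see `x` in this signature.
    return p[0] * np.exp(-p[1] * x)

xdata = np.linspace(0, 4, 50)
ydata = 2.0 * np.exp(-1.0 * xdata)  # synthetic data with a=2, b=1

# Without p0, curve_fit cannot tell how many parameters to fit.
try:
    curve_fit(model, xdata, ydata)
    error = None
except ValueError as exc:
    error = exc

# With p0, len(p0) fixes the number of fitted parameters.
popt, pcov = curve_fit(model, xdata, ydata, p0=[1, 1])
```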
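from scipy.optimize import curve_fit
import numpy as np

def func(x, *args):
    # Column vector of x values, shape (len(x), 1)
    x = x.reshape(-1, 1)
    # Row vectors of amplitudes (a1, a2, ...) and rates (b1, b2, ...), shape (1, n)
    a = np.array(args[0::2]).reshape(1, -1)
    b = np.array(args[1::2]).reshape(1, -1)
    # Broadcasting gives a (len(x), n) array; summing over axis 1 adds the n terms
    return np.sum(a * np.exp(-b * x), axis=1)

n = 3
xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
# p0 has 2 * n entries, which tells curve_fit to fit n terms
popt, pcov = curve_fit(func, xdata, ydata, p0=[1, 1] * n)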
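Note that a * np.exp(-b * x) is an exponential decay rather than a Gaussian. If actual Gaussian peaks are wanted, the same *args pattern should extend to three parameters per peak. The sketch below assumes the parameterization a * exp(-(x - mu)**2 / (2 * sigma**2)); the function name gaussians, the synthetic data and the starting guesses are hypothetical. Starting every term from identical values (as [1, 1] * n does) makes the terms interchangeable, so here the initial centres are deliberately distinct.

```python
import numpy as np
from scipy.optimize import curve_fit

def gaussians(x, *args):
    """Sum of Gaussians a_i * exp(-(x - mu_i)**2 / (2 * sigma_i**2)).

    args is (a1, mu1, sigma1, a2, mu2, sigma2, ...).
    """
    x = np.asarray(x, dtype=float).reshape(-1, 1)    # shape (len(x), 1)
    p = np.asarray(args, dtype=float).reshape(-1, 3)  # one row per peak
    a, mu, sigma = p[:, 0], p[:, 1], p[:, 2]          # each shape (n,)
    # Broadcasting gives (len(x), n); sum over the peaks.
    return np.sum(a * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)), axis=1)

# Synthetic two-peak data (hypothetical values, no noise).
xdata = np.linspace(0, 10, 200)
true_params = [3.0, 3.0, 0.8, 1.5, 6.5, 1.2]
ydata = gaussians(xdata, *true_params)

# Three entries per peak; distinct centres break the symmetry between peaks.
p0 = [1.0, 2.5, 1.0, 1.0, 7.0, 1.0]
popt, pcov = curve_fit(gaussians, xdata, ydata, p0=p0)
peaks = popt.reshape(-1, 3)  # rows: (amplitude, centre, width) per peak
```

Because sigma only enters squared, a fitted width can come back negative; np.abs(peaks[:, 2]) gives the conventional positive width.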