akozi

Reputation: 455

Most pythonic way to fit multiple gaussians using scipy.optimize

import numpy as np
from scipy.optimize import curve_fit

def func(x, a, b):
    return a * np.exp(-b * x)

# Placeholder data
xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
popt, pcov = curve_fit(func, xdata, ydata)

It is quite easy to fit an arbitrary Gaussian in Python with something like the above method. However, I would like to prepare a function that allows the user to select an arbitrary number of Gaussians and still attempts to find a best fit.

I'm trying to figure out how to modify the function func so that I can pass it an additional parameter, n=2 for instance, and have it return a function that tries to fit 2 Gaussians, akin to:

import numpy as np
from scipy.optimize import curve_fit

def func2(x, a, b, c, d):
    return a * np.exp(-b * x) + c * np.exp(-d * x)

# Placeholder data
xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
popt, pcov = curve_fit(func2, xdata, ydata)

I would like to do this without hard-coding each additional case, so that I could instead pass a single function like func(..., n=2) and get the same result as above. I'm having trouble finding an elegant solution to this. My guess is that the best solution will use a lambda function.

Upvotes: 2

Views: 1360

Answers (1)

MB-F

Reputation: 23647

You can define the function to take a variable number of arguments using def func(x, *args). The *args variable then contains something like (a1, b1, a2, b2, a3, ...). It would be possible to loop over those and sum the Gaussians but I'm showing a vectorized solution instead.
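
For comparison, the loop-based variant mentioned above might look roughly like the sketch below. The name func_loop and the parameter-count check are illustrative additions, not part of the original answer; the terms are the same a * exp(-b * x) form used in the question.

```python
import numpy as np

def func_loop(x, *args):
    """Sum of terms a_i * exp(-b_i * x); args is (a1, b1, a2, b2, ...)."""
    # Hypothetical safety check: parameters must come in (a, b) pairs.
    if len(args) % 2 != 0:
        raise ValueError("expected an even number of parameters")
    x = np.asarray(x, dtype=float)
    total = np.zeros_like(x)
    # Walk through the flat parameter tuple two values at a time.
    for a, b in zip(args[0::2], args[1::2]):
        total += a * np.exp(-b * x)
    return total
```

This gives the same values as a vectorized version; the loop is simply easier to read when the number of terms is small.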

Since curve_fit can no longer determine the number of parameters from this function's signature, you have to provide an initial guess, and its length determines the number of Gaussians you want to fit. Each Gaussian requires two parameters, so [1, 1] * n produces a parameter vector of the correct length.

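from scipy.optimize import curve_fit
import numpy as np

def func(x, *args):
    # args is a flat sequence (a1, b1, a2, b2, ...)
    x = x.reshape(-1, 1)                     # column vector: one row per sample
    a = np.array(args[0::2]).reshape(1, -1)  # row vector of the a_i
    b = np.array(args[1::2]).reshape(1, -1)  # row vector of the b_i
    # Broadcasting yields a (samples, n) array; sum over the n terms
    return np.sum(a * np.exp(-b * x), axis=1)

n = 3  # number of terms to fit

# Placeholder data
xdata = np.linspace(0, 4, 50)
ydata = np.linspace(0, 4, 50)
popt, pcov = curve_fit(func, xdata, ydata, p0=[1, 1] * n)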
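
Note that the terms above are exponentials of the form a * exp(-b * x), following the question's code. If actual Gaussian peaks are wanted, the same *args idea should carry over with three parameters per peak. The sketch below assumes an (amplitude, mean, width) parameterization; the name gaussian_sum, the synthetic data and the starting values are made up for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def gaussian_sum(x, *params):
    """Sum of Gaussians; params is (amp1, mu1, sigma1, amp2, mu2, sigma2, ...)."""
    # One row per peak, columns are amplitude, mean, width (assumed layout).
    p = np.asarray(params, dtype=float).reshape(-1, 3)
    amp, mu, sigma = p[:, 0], p[:, 1], p[:, 2]
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    # Broadcasting gives a (samples, n_peaks) array; sum over the peaks.
    return np.sum(amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2), axis=1)

# Synthetic, noise-free data from two well-separated peaks (illustrative only).
xdata = np.linspace(0, 10, 200)
true_params = [1.0, 3.0, 0.5, 2.0, 7.0, 1.0]
ydata = gaussian_sum(xdata, *true_params)

# The length of p0 (3 values per peak) tells curve_fit how many peaks to fit.
p0 = [1.0, 2.5, 1.0, 1.0, 6.5, 1.0]
popt, pcov = curve_fit(gaussian_sum, xdata, ydata, p0=p0)
```

Unlike [1, 1] * n, the starting values here differ per peak. With identical guesses for every peak the optimizer has no reason to separate them, so rough estimates of the peak positions are usually worth supplying.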

Upvotes: 3
