I have a problem with the `curve_fit` function of `scipy.optimize`. I have a rather complex class that calculates the propagation of light, and I want to fit it to some measurements. For this question, here is a simple example of how the class is structured:
```python
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


class Gauss():
    def __init__(self, x, parDict={}):
        self.default_dict = {"Amplitude": 1, "Center": 0, "FWHM": 1}
        self.x = x
        if parDict:
            self.parDict = self.fillDict(parDict)
        else:
            self.parDict = self.default_dict

    def fillDict(self, dictionary):
        # Use the user-supplied value where given, otherwise the default
        return_dict = {}
        if "Amplitude" in dictionary.keys():
            return_dict["Amplitude"] = dictionary["Amplitude"]
        else:
            return_dict["Amplitude"] = self.default_dict["Amplitude"]
        if "Center" in dictionary.keys():
            return_dict["Center"] = dictionary["Center"]
        else:
            return_dict["Center"] = self.default_dict["Center"]
        if "FWHM" in dictionary.keys():
            return_dict["FWHM"] = dictionary["FWHM"]
        else:
            return_dict["FWHM"] = self.default_dict["FWHM"]
        return return_dict

    def calculate(self):
        return self.parDict["Amplitude"]*np.exp(-((self.x-self.parDict["Center"])**2/(self.parDict["FWHM"]**2*np.sqrt(2))))


x = np.linspace(-5, 5, 100)
g = Gauss(x)
y = g.calculate()
plt.plot(x, y)
plt.show()
```
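The `fillDict` method above keeps every default key and overrides it wherever the caller supplied a value. Assuming only the keys in the default dictionary matter, the same merge can be written as a single dictionary comprehension. `fill_dict` and `DEFAULT_PARAMS` are hypothetical names used for illustration, not part of the original project:

```python
# Hypothetical standalone helper (not part of the original class) that
# reproduces what Gauss.fillDict does: every default key is kept, a
# user-supplied value overrides the default, and unknown keys are dropped.
DEFAULT_PARAMS = {"Amplitude": 1, "Center": 0, "FWHM": 1}


def fill_dict(overrides, defaults=DEFAULT_PARAMS):
    """Return a new dict with exactly one entry per default key."""
    return {key: overrides.get(key, value) for key, value in defaults.items()}


print(fill_dict({"Amplitude": 1.2, "FWHM": 0.5}))
# {'Amplitude': 1.2, 'Center': 0, 'FWHM': 0.5}
```

Unlike the original `__init__`, which reuses `self.default_dict` itself when no dictionary is passed, this always returns a new dictionary, so modifying the result cannot change the defaults.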
As you can see, all parameters are stored in a dictionary for simpler use, and a default dictionary supplies every parameter I don't want to change. Now I am struggling to find a clever way, or indeed any way, to write a function or class that uses this class to fit it to some measurements, like this:
```python
ynoise = y + np.random.normal(loc=0.2, scale=0.2, size=len(x))
param = {"Amplitude": 1.2, "FWHM": 0.5}
fit_par, _ = curve_fit(g.calculate, x, ynoise, p0=param)
```
The problem is that this class is part of a much larger project and is used inside a program, so I cannot modify the class itself. Can anyone think of a solution?
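The attempt above cannot work as written: `curve_fit` calls the model as `f(xdata, *params)`, passing the parameters as separate numbers, and expects `p0` to be a flat sequence of numbers, whereas `g.calculate()` takes no arguments and `param` is a dictionary. For comparison, here is a minimal sketch of the calling convention `curve_fit` expects. `gauss_model` is a hypothetical plain function that repeats the expression from `calculate` with positional parameters, and the data are synthetic:

```python
import numpy as np
from scipy.optimize import curve_fit


def gauss_model(x, amplitude, center, fwhm):
    # Same expression as Gauss.calculate, but with the parameters passed
    # as separate positional arguments, which is what curve_fit expects.
    return amplitude * np.exp(-((x - center) ** 2 / (fwhm ** 2 * np.sqrt(2))))


x = np.linspace(-5, 5, 100)
rng = np.random.default_rng(0)
y_true = gauss_model(x, 1.0, 0.0, 1.0)
y_noisy = y_true + rng.normal(scale=0.05, size=x.size)

# p0 is a flat list of numbers in the same order as the model's arguments.
popt, pcov = curve_fit(gauss_model, x, y_noisy, p0=[1.2, 0.0, 0.5])
print(popt)  # approximately 1, 0 and 1; exact digits depend on the noise
```

Since the FWHM only enters the model squared, its sign is not determined by the fit; starting from a positive guess normally keeps it positive.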
With the help of Bob's comment on his answer to this post, I was able to solve my problem. For anyone who wants the solution:
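```python
def fitFunc(x, dictionary):
    # Build a fresh Gauss object with the current parameter values
    g = Gauss(x, dictionary)
    return g.calculate()

# curve_fit passes the fit parameters as separate positional values (*args),
# so rebuild a dictionary keyed by the names in param
fit_par, _ = curve_fit(lambda x, *args: fitFunc(x, {k: v for k, v in zip(param.keys(), args)}),
                       x, ynoise, [param[k] for k in param.keys()])
fit_par = {k: v for k, v in zip(param.keys(), fit_par)}  # map fitted values back to names
```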
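The same idea can be packaged as a reusable helper so the parameter names are written only once. This is a sketch, not part of any library: `fit_with_dict` is a hypothetical name, and the compact `Gauss` below is a stand-in with the same constructor, defaults, and `calculate()` as the class in the question. Parameters not named in the initial-guess dictionary are held at the class defaults:

```python
import numpy as np
from scipy.optimize import curve_fit


class Gauss:
    # Compact stand-in for the class in the question (same constructor
    # signature, same calculate() method, same default values).
    default_dict = {"Amplitude": 1, "Center": 0, "FWHM": 1}

    def __init__(self, x, parDict=None):
        self.x = x
        parDict = parDict or {}
        self.parDict = {k: parDict.get(k, v) for k, v in self.default_dict.items()}

    def calculate(self):
        p = self.parDict
        return p["Amplitude"] * np.exp(-((self.x - p["Center"]) ** 2 / (p["FWHM"] ** 2 * np.sqrt(2))))


def fit_with_dict(model_cls, x, y, p0_dict, **curve_fit_kwargs):
    """Hypothetical helper: fit only the parameters named in p0_dict.

    model_cls must accept (x, parameter_dict) and provide calculate();
    parameters missing from p0_dict keep the class defaults.
    """
    names = list(p0_dict)  # fixes the order of the flat parameter vector

    def wrapper(x, *values):
        # Rebuild a dictionary from the flat values curve_fit passes in.
        return model_cls(x, dict(zip(names, values))).calculate()

    p0 = [p0_dict[name] for name in names]
    popt, pcov = curve_fit(wrapper, x, y, p0=p0, **curve_fit_kwargs)
    return dict(zip(names, popt)), pcov


x = np.linspace(-5, 5, 100)
y = Gauss(x, {"Amplitude": 1.5, "FWHM": 0.8}).calculate()  # noise-free test data
fit_par, _ = fit_with_dict(Gauss, x, y, {"Amplitude": 1.2, "FWHM": 0.5})
print(fit_par)  # Amplitude close to 1.5, FWHM close to 0.8; Center stays at 0
```

Because extra keyword arguments are forwarded to `curve_fit`, options such as `bounds` can be passed through, for example `bounds=([0, 0], [np.inf, np.inf])` to keep both fitted parameters non-negative.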
Thank you very much!