Reputation: 536
I have a standard Gaussian function that looks like this:
def gauss_fnc(x, amp, cen, sigma):
    """Evaluate an unnormalised Gaussian peak at ``x``.

    ``amp`` is the height of the peak, ``cen`` the position of its maximum and
    ``sigma`` its standard-deviation-like width. Works element-wise on arrays.
    """
    offset = x - cen
    exponent = -offset ** 2 / (2 * sigma ** 2)
    return amp * np.exp(exponent)
And I have a fit_gaussian function that uses scipy's curve_fit to fit my gauss_fnc:
from scipy.optimize import curve_fit
def fit_gaussian(x, y):
    """Fit ``gauss_fnc`` to the samples ``(x, y)`` with scipy's ``curve_fit``.

    The starting point for the fit comes from the data: the amplitude is
    ``max(y)``, and the centre and width are the y-weighted mean and standard
    deviation of ``x``.

    Parameters
    ----------
    x, y : array_like
        Sample positions and values of equal length (``range`` objects are
        accepted). ``y`` is used as a weight for the initial estimates, so it
        is expected to be non-negative.

    Returns
    -------
    values : ndarray
        The fitted Gaussian evaluated at ``x``.
    sigma : float
        The moment-based *initial* width estimate. The fitted width is
        ``opt[2]``.
    opt : ndarray
        Fitted parameters ``[amp, cen, sigma]``.
    cov : ndarray
        Estimated covariance of ``opt`` as returned by ``curve_fit``.

    Raises
    ------
    ValueError
        If ``y`` sums to zero, so no weighted centre/width can be estimated.
    RuntimeError
        Propagated from ``curve_fit`` when the fit does not converge.
    """
    # curve_fit only converts list/tuple/ndarray xdata; anything else (e.g. a
    # range) would reach gauss_fnc unconverted, so normalise both inputs here.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    total = np.sum(y)
    if total == 0:
        raise ValueError("sum(y) is zero; cannot estimate the initial centre and width")
    mean = np.sum(x * y) / total
    sigma = np.sqrt(np.sum(y * (x - mean) ** 2) / total)
    opt, cov = curve_fit(gauss_fnc, x, y, p0=[np.max(y), mean, sigma])
    values = gauss_fnc(x, *opt)
    return values, sigma, opt, cov
I can confirm that this works well when the data resembles a normal Gaussian curve; see this example:
However, if the signal is too sharply peaked (narrow) or too flat, the fit does not work as expected. Here is an example of a peaked Gaussian:
Here is an example of a flat-top (super-Gaussian) signal:
Currently, the flatter the signal becomes, the more information is lost, because the Gaussian fit cuts off the edges. How can I improve the functions or the curve fitting so that peaked and flat-top signals can be fitted as well, like in this picture:
Edit:
I provided a minimal working example to try this out:
from PyQt5.QtWidgets import (QApplication, QMainWindow)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from scipy.optimize import curve_fit
import numpy as np
from PyQt5.QtWidgets import QWidget, QGridLayout
def gauss_fnc(x, amp, cen, sigma):
    """Return the Gaussian ``amp * exp(-(x - cen)**2 / (2 * sigma**2))`` at ``x``."""
    squared_distance = (x - cen) ** 2
    return amp * np.exp(-squared_distance / (2 * sigma ** 2))
def fit_gauss(x, y):
    """Fit ``gauss_fnc`` to the samples ``(x, y)`` with scipy's ``curve_fit``.

    The starting point for the fit comes from the data: the amplitude is
    ``max(y)``, and the centre and width are the y-weighted mean and standard
    deviation of ``x``.

    Parameters
    ----------
    x, y : array_like
        Sample positions and values of equal length (``range`` objects are
        accepted). ``y`` is used as a weight for the initial estimates, so it
        is expected to be non-negative.

    Returns
    -------
    vals : ndarray
        The fitted Gaussian evaluated at ``x``.
    sigma : float
        The moment-based *initial* width estimate. The fitted width is
        ``opt[2]``.
    opt : ndarray
        Fitted parameters ``[amp, cen, sigma]``.
    cov : ndarray
        Estimated covariance of ``opt`` as returned by ``curve_fit``.

    Raises
    ------
    ValueError
        If ``y`` sums to zero, so no weighted centre/width can be estimated.
    RuntimeError
        Propagated from ``curve_fit`` when the fit does not converge.
    """
    # curve_fit only converts list/tuple/ndarray xdata; anything else (e.g. a
    # range) would reach gauss_fnc unconverted, so normalise both inputs here.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    total = np.sum(y)
    if total == 0:
        raise ValueError("sum(y) is zero; cannot estimate the initial centre and width")
    mean = np.sum(x * y) / total
    sigma = np.sqrt(np.sum(y * (x - mean) ** 2) / total)
    opt, cov = curve_fit(gauss_fnc, x, y, p0=[np.max(y), mean, sigma])
    vals = gauss_fnc(x, *opt)
    return vals, sigma, opt, cov
class MainWindow(QMainWindow):
    """Window with three stacked plots, each showing a raw signal and its Gaussian fit."""

    def __init__(self):
        super().__init__()
        self.results = list()
        self.setWindowTitle('Gauss fitting')
        self.setGeometry(50, 50, 1280, 1024)
        self.setupLayout()
        # Example signals: a regular peak, a single sharp spike, and a flat top.
        self.raw_data1 = np.array([1, 1, 1, 1, 3, 5, 7, 8, 9, 10, 11, 10, 9, 8, 7, 5, 3, 1, 1, 1, 1], dtype=int)
        self.raw_data2 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
        self.raw_data3 = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
        self.plot()

    @staticmethod
    def _make_canvas():
        """Create one figure canvas plus its axes, with y ticks and label on the right."""
        canvas = FigureCanvas(Figure(figsize=(5, 4), dpi=100))
        ax = canvas.figure.add_subplot(111, frameon=False)
        ax.get_xaxis().set_visible(True)
        ax.get_yaxis().set_visible(True)
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")
        return canvas, ax

    def setupLayout(self):
        """Build the three canvases (fig1..fig3 / fig1AX..fig3AX) in a single-column grid."""
        self.fig1, self.fig1AX = self._make_canvas()
        self.fig2, self.fig2AX = self._make_canvas()
        self.fig3, self.fig3AX = self._make_canvas()
        self.widget = QWidget(self)
        grid = QGridLayout()
        for row, canvas in enumerate((self.fig1, self.fig2, self.fig3)):
            grid.addWidget(canvas, row, 0, 1, 1)
        self.widget.setLayout(grid)
        self.setCentralWidget(self.widget)

    @staticmethod
    def _plot_fit(canvas, ax, data):
        """Draw ``data`` (black) and its own Gaussian fit (blue) on one canvas."""
        x = np.arange(len(data))
        ax.clear()
        ax.plot(x, data, 'k-')
        try:
            fitted = fit_gauss(x, data)[0]
        except RuntimeError:
            # curve_fit raises RuntimeError when it fails to converge; show the
            # raw data alone instead of letting the exception kill the GUI.
            pass
        else:
            ax.plot(x, fitted, 'b-', linewidth=2)
        ax.margins(0, 0)
        canvas.figure.tight_layout()
        canvas.draw()

    def plot(self):
        """Fit and draw every raw data set on its own canvas.

        Each panel is paired with its own data: raw_data1 -> fig1,
        raw_data2 -> fig2, raw_data3 -> fig3.
        """
        panels = (
            (self.fig1, self.fig1AX, self.raw_data1),
            (self.fig2, self.fig2AX, self.raw_data2),
            (self.fig3, self.fig3AX, self.raw_data3),
        )
        for canvas, ax, data in panels:
            self._plot_fit(canvas, ax, data)
if __name__ == '__main__':
    import sys

    # Pass argv so Qt can process its own command-line options.
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    # Propagate the event loop's return code as the process exit status.
    sys.exit(app.exec_())
The last picture is from here.
Upvotes: 3
Views: 2199
Reputation: 5740
You can use the Gaussian function defined below for the curve fit:
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
# Sample positions as a float array: curve_fit hands xdata to the model
# function unconverted unless it is a list/tuple/ndarray, so avoid a bare range.
x = np.arange(21, dtype=float)
y_peak = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
y_flat_top = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
# Gaussian model used for the curve fit.
def Gauss(x, a, x0, sigma):
    """Return ``a * exp(-(x - x0)**2 / (2 * sigma**2))``: height ``a``, centre ``x0``, width ``sigma``."""
    distance = x - x0
    return a * np.exp(-distance**2 / (2 * sigma**2))
# Fit the sharp peak. Without p0, curve_fit starts every parameter at 1
# (centre at x=1, far from the peak), so seed it from the data instead.
p0 = [y_peak.max(), x[np.argmax(y_peak)], 1.0]
popt, pcov = curve_fit(Gauss, x, y_peak, p0=p0)
# Evaluate the fitted curve on a fine grid spanning the sampled range.
x_fit = np.linspace(x[0], x[-1], 1000)
y_fit = Gauss(x_fit, *popt)
# plot data together with the fitted curve
fig, ax = plt.subplots()
ax.plot(x, y_peak, '.')
ax.plot(x_fit, y_fit, '-', label='peak')
ax.legend()
plt.show()
Output:
Using the generalized normal distribution: it is hard to fit. You can play with the bounds and try adding extra parameters to the function to get a better fit. Another option is to use a differential evolution algorithm to find the best fit.
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
# set data: 21 samples on [-4, 4] so the middle sample (index 10, where both
# the spike and the plateau are centred) lies exactly at x = 0, the centre of
# the general_norm model.
x = np.linspace(-4, 4, 21)
y_peak = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
y_flat_top = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
y = y_peak
# Generalized-normal-style profile used as the fit model.
def general_norm(x, gamma, beta):
    """Return ``beta / (2 * gamma * (1 / beta)) * exp(-|x| ** beta)``.

    The profile is centred at ``x = 0``. ``beta`` sets the shape: small values
    give a sharp cusp and large values a flat top. The prefactor reduces to
    ``beta**2 / (2 * gamma)``, so ``gamma`` only scales the height; the width
    is fixed (the profile falls to ``exp(-1)`` of its peak at ``|x| = 1``).

    NOTE(review): the textbook generalized normal PDF is
    ``beta / (2 * alpha * Gamma(1 / beta)) * exp(-(|x - mu| / alpha) ** beta)``.
    Here ``1 / beta`` appears where the Gamma function should, and there is
    no centre or scale inside the exponent. This works as a fit model only
    because the height is free -- confirm that this is intended.
    """
    prefactor = beta / (2 * gamma * (1 / beta))
    return prefactor * np.exp(-(np.abs(x) ** beta))
# set bounds as ((gamma_min, beta_min), (gamma_max, beta_max))
bounds_peak = ((0, 0), (100, 9))
# beta >= 7 forces a flat-topped profile
bounds_flat_top = ((0, 7), (100, 9))
# use the bounds that match the selected data set
bounds = bounds_flat_top if y is y_flat_top else bounds_peak
# fit function
popt, pcov = curve_fit(general_norm, x, y, bounds=bounds)
# root-mean-square of the fit residuals
residuals = y - general_norm(x, *popt)
rms = np.sqrt(np.mean(residuals ** 2))
# set data for curve plot
x_fit = np.linspace(-4, 4, 1000)
y_fit = general_norm(x_fit, *popt)
# plot data together with the fitted curve
fig, ax = plt.subplots(1, 1)
ax.plot(x, y, '.')
ax.plot(x_fit, y_fit, 'b-')
plt.show()
Output:
Upvotes: 1