Reputation: 145
I'd like to write a more or less universal fit function for the general function
$f(t) = \sum_i a_i \exp(-t/\tau_i)$
for some data I have.
Below is example code for a biexponential function, but I would like to be able to fit a monoexponential or a triexponential function with the smallest code adaptations possible.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
t = np.linspace(0, 10, 100)
a_1 = 1
a_2 = 1
tau_1 = 5
tau_2 = 1
data = a_1*np.exp(-t/tau_1) + a_2*np.exp(-t/tau_2)
data += 0.2 * np.random.normal(size=t.size)
def func(t, a_1, tau_1, a_2, tau_2):  # add more (a_i, tau_i) pairs for more exponentials
    return a_1*np.exp(-t/tau_1) + a_2*np.exp(-t/tau_2)
popt, pcov = curve_fit(func, t, data)
print(popt)
plt.plot(t, data, label="data")
plt.plot(t, func(t, *popt), label="fit")
plt.legend()
plt.show()
In principle, I thought of redefining the function in a general form
def func(t, a, tau):  # with a and tau as lists of equal length
    tmp = 0
    for i in range(len(a)):
        tmp += a[i]*np.exp(-t/tau[i])
    return tmp
and passing the arguments to curve_fit in the form of lists or tuples. However, I get a TypeError as shown below.
TypeError: func() takes 4 positional arguments but 7 were given
Is there any way to rewrite the code so that the degree of the multiexponential function is determined solely by the input parameters passed to curve_fit? So that passing
a = (1,)
results in a monoexponential function whereas passing
a = (1, 2, 3)
results in a triexponential function?
Regards
Upvotes: 2
Views: 1037
Reputation: 1481
Yes, that can be done easily with NumPy broadcasting:
def func(t, a, taus):  # a and taus are sequences of equal length
    a = np.array(a)[:, None]
    taus = np.array(taus)[:, None]
    return (a*np.exp(-t/taus)).sum(axis=0)
func accepts two lists, converts them into 2-D NumPy arrays (column vectors), computes a matrix with one row per exponential, and then sums over the rows. Example:
t=np.arange(100).astype(float)
out=func(t,[1,2],[0.3,4])
plt.plot(out)
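To make the broadcasting step more concrete, here is a minimal sketch that only inspects the array shapes involved. The variable names terms and total are mine, chosen for illustration; the values are the ones from the example above with a shorter t.

```python
import numpy as np

t = np.arange(5, dtype=float)          # shape (5,)
a = np.array([1.0, 2.0])[:, None]      # shape (2, 1): one row per exponential
taus = np.array([0.3, 4.0])[:, None]   # shape (2, 1)

# (2, 1) combined with (5,) broadcasts to (2, 5):
# row i holds a_i * exp(-t / tau_i) for every time point
terms = a * np.exp(-t / taus)

# Summing over axis 0 collapses the rows into the multiexponential curve
total = terms.sum(axis=0)

print(terms.shape, total.shape)
```

The [:, None] indexing is what turns each list into a column, so that every amplitude and time constant is paired with every time point.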
Keep in mind that a and taus must be the same length, so sanitize your inputs as you see fit. You could also pass NumPy arrays directly instead of lists.
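One caveat when plugging this into curve_fit: it calls the model with the fit parameters as separate positional arguments, not as two lists, which is why the list-based signature runs into a TypeError. A possible workaround, sketched here under my own assumptions, is a wrapper that takes *params and splits them into amplitudes and time constants. The name multi_exp and the interleaved layout (a_1, tau_1, a_2, tau_2, ...) are hypothetical choices, not something the question prescribes. Because the signature is variadic, curve_fit cannot infer the number of parameters, so p0 has to be supplied, and its length then sets the number of exponentials.

```python
import numpy as np
from scipy.optimize import curve_fit


def multi_exp(t, *params):
    """Sum of exponentials; params is laid out as (a_1, tau_1, a_2, tau_2, ...)."""
    # Basic input check: parameters must come in (a_i, tau_i) pairs
    if len(params) == 0 or len(params) % 2 != 0:
        raise ValueError("params must contain (a_i, tau_i) pairs")
    a = np.asarray(params[0::2], dtype=float)[:, None]     # amplitudes as a column
    taus = np.asarray(params[1::2], dtype=float)[:, None]  # time constants as a column
    return (a * np.exp(-t / taus)).sum(axis=0)


# Synthetic biexponential data, similar to the question but with less noise
rng = np.random.default_rng(0)
t = np.linspace(0, 10, 100)
data = multi_exp(t, 1, 5, 1, 1) + 0.05 * rng.normal(size=t.size)

# Four initial guesses -> biexponential fit; two would give a monoexponential,
# six a triexponential, with no change to multi_exp itself
p0 = (1.0, 4.0, 1.0, 0.5)
popt, pcov = curve_fit(multi_exp, t, data, p0=p0)
print(popt)
```

Multiexponential fits tend to be sensitive to the starting values, so reasonable initial guesses in p0 matter more as the number of terms grows.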
Upvotes: 3