JakobJakobson13

Reputation: 145

How to write a flexible multiple exponential fit

I'd like to write a more or less universal fit function for the general multiexponential function

$f(t) = \sum_{i=1}^{n} a_i \exp(-t/\tau_i)$

for some data I have.
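Written out explicitly, the model with n terms has 2n free parameters (one amplitude a_i and one lifetime τ_i per term), and the mono-, bi- and triexponential cases correspond to n = 1, 2, 3:

```latex
\begin{aligned}
n = 1:&\quad f(t) = a_1 e^{-t/\tau_1} \\
n = 2:&\quad f(t) = a_1 e^{-t/\tau_1} + a_2 e^{-t/\tau_2} \\
n = 3:&\quad f(t) = a_1 e^{-t/\tau_1} + a_2 e^{-t/\tau_2} + a_3 e^{-t/\tau_3}
\end{aligned}
```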

Below is example code for a biexponential function, but I would like to be able to fit a monoexponential or a triexponential function with as few code changes as possible.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from scipy.optimize import curve_fit

import matplotlib.pyplot as plt

t = np.linspace(0, 10, 100)

# true parameters of the synthetic biexponential decay
a_1 = 1
a_2 = 1
tau_1 = 5
tau_2 = 1

data = a_1*np.exp(-t/tau_1) + a_2*np.exp(-t/tau_2)
data += 0.2 * np.random.normal(size=t.size)  # add Gaussian noise

def func(t, a_1, tau_1, a_2, tau_2): # plus more exponential functions
    return a_1*np.exp(-t/tau_1)+a_2*np.exp(-t/tau_2)

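# starting guess (a_1, tau_1, a_2, tau_2); the default p0 (all ones) starts both terms identical
popt, pcov = curve_fit(func, t, data, p0=(1, 4, 1, 0.5))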
print(popt)
plt.plot(t, data, label="data")
plt.plot(t, func(t, *popt), label="fit")
plt.legend()
plt.show()

In principle I thought of redefining the function to a general form

def func(t, a, tau): # with a and tau as a list
    tmp = 0
    for i in range(len(a)):  # add one exponential term per entry
        tmp += a[i]*np.exp(-t/tau[i])
    return tmp

and passing the arguments to curve_fit in the form of lists or tuples. However, I get a TypeError as shown below.

TypeError: func() takes 4 positional arguments but 7 were given
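The error comes from how curve_fit calls the model: it passes the fit parameters unpacked, as f(xdata, *params), one positional argument per parameter, so a signature that expects two lists can never receive them as lists. A minimal sketch reproducing the mechanism, assuming SciPy's curve_fit with an explicit p0 of four values; list_model and the noise-free data are hypothetical illustrations:

```python
import numpy as np
from scipy.optimize import curve_fit

def list_model(t, a, tau):
    # Expects a and tau as sequences of amplitudes and lifetimes
    return sum(a_i * np.exp(-t / tau_i) for a_i, tau_i in zip(a, tau))

t = np.linspace(0, 10, 50)
data = np.exp(-t / 5) + np.exp(-t / 1)

try:
    # Four starting values, so curve_fit calls list_model(t, p0, p1, p2, p3)
    curve_fit(list_model, t, data, p0=[1, 5, 1, 1])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The counts in the message simply reflect how many parameters curve_fit tried to pass.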

Is there any way to rewrite the code so that the degree of the multiexponential function is determined solely by the input parameters passed to curve_fit? So that passing

a = (1,)

results in a monoexponential function whereas passing

a = (1, 2, 3)

results in a triexponential function?

Regards

Upvotes: 2

Views: 1037

Answers (1)

Brenlla

Reputation: 1481

Yes, that can be done easily with NumPy broadcasting:

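def func(t, a, taus):  # a and taus: sequences of equal length n
    a = np.array(a)[:, None]        # column vector, shape (n, 1)
    taus = np.array(taus)[:, None]  # column vector, shape (n, 1)
    # broadcasting against t of shape (m,) gives an (n, m) matrix; sum over the n terms
    return (a*np.exp(-t/taus)).sum(axis=0)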

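func accepts two lists and converts them into 2-D np.arrays of shape (n, 1), i.e. column vectors. Broadcasting these against t, of shape (m,), computes an (n, m) matrix with one exponential term per row, which is then summed over axis 0. Example: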

t=np.arange(100).astype(float)
out=func(t,[1,2],[0.3,4])
plt.plot(out)
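To make the shapes concrete, here is a small check assuming only NumPy (the answer's func is repeated with an added float cast so the block runs on its own): with two terms and five time points, the (2, 1) column vectors broadcast against t of shape (5,) to a (2, 5) matrix, and summing over axis 0 leaves one value per time point, matching an explicit loop over the terms.

```python
import numpy as np

def func(t, a, taus):
    # Column vectors of shape (n, 1) broadcast against t of shape (m,)
    a = np.asarray(a, dtype=float)[:, None]
    taus = np.asarray(taus, dtype=float)[:, None]
    return (a * np.exp(-t / taus)).sum(axis=0)

t = np.linspace(0, 4, 5)               # shape (5,)
a = np.array([1.0, 2.0])[:, None]      # shape (2, 1)
taus = np.array([0.3, 4.0])[:, None]   # shape (2, 1)

terms = a * np.exp(-t / taus)          # one row per exponential term
print(terms.shape)                     # (2, 5)

out = func(t, [1, 2], [0.3, 4])
print(out.shape)                       # (5,)

# Same result with an explicit loop over the terms
loop = sum(a_i * np.exp(-t / tau_i) for a_i, tau_i in zip([1, 2], [0.3, 4]))
print(np.allclose(out, loop))          # True
```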

Keep in mind that a and taus must have the same length, so sanitize your inputs as you see fit. You could also pass np.arrays directly instead of lists.
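Because curve_fit passes the fit parameters as one flat list of positional arguments, a two-list function still needs a thin adapter before it can be fitted. A minimal sketch, not part of the answer above, assuming the parameters are interleaved as (a_1, tau_1, a_2, tau_2, ...); the names multi_exp and fit_multi_exp are hypothetical. Since the adapter is variadic, curve_fit cannot count its parameters from the signature, so p0 is required, and the length of p0 then fixes the number of exponential terms, as the question asks.

```python
import numpy as np
from scipy.optimize import curve_fit

def multi_exp(t, *params):
    # params is flat and interleaved: a_1, tau_1, a_2, tau_2, ...
    if len(params) % 2 != 0:
        raise ValueError("expected (a_i, tau_i) pairs")
    a = np.asarray(params[0::2], dtype=float)[:, None]
    taus = np.asarray(params[1::2], dtype=float)[:, None]
    return (a * np.exp(-t / taus)).sum(axis=0)

def fit_multi_exp(t, data, a0, tau0):
    # The number of terms is len(a0); interleave the starting values into p0
    if len(a0) != len(tau0):
        raise ValueError("a0 and tau0 must have the same length")
    p0 = np.column_stack([a0, tau0]).ravel()
    popt, pcov = curve_fit(multi_exp, t, data, p0=p0)
    return popt[0::2], popt[1::2], pcov

rng = np.random.default_rng(0)
t = np.linspace(0, 10, 200)
data = np.exp(-t / 5) + np.exp(-t / 1) + 0.01 * rng.normal(size=t.size)

a_fit, tau_fit, _ = fit_multi_exp(t, data, a0=(1, 1), tau0=(4, 0.5))
print(a_fit, tau_fit)
```

A monoexponential fit is then fit_multi_exp(t, data, a0=(1,), tau0=(2,)), and a triexponential one passes three starting values for each. The trailing comma matters: (1) is just the integer 1, while (1,) is a one-element tuple.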

Upvotes: 3
