Reputation: 12026
I've got some big datasets to which I'd like to fit monoexponential time decays.
The data consists of multiple 4D datasets, acquired at different times, and the fit should thus run along a 5th dimension (through datasets).
The code I'm currently using is the following:
import numpy as np
import scipy.optimize as opt
[... load 4D datasets ....]
data = (dataset1, dataset2, dataset3)
times = (10, 20, 30)
def monoexponential(t, M0, t_const):
    return M0*np.exp(-t/t_const)
# Starting guesses to initiate descent.
M0_init = 80.0
t_const_init = 50.0
init_guess = (M0_init, t_const_init)
def fit(vector):
    try:
        # bounds are (lower, upper) pairs: 0 <= M0 <= 2000, 0 <= t_const <= 800.
        # (The original ([0, 2000], [0, 800]) had lower bounds above the upper
        # ones, which makes curve_fit raise immediately.)
        nlfit, nlpcov = opt.curve_fit(monoexponential, times, vector,
                                      p0=init_guess,
                                      sigma=None,
                                      check_finite=False,
                                      maxfev=100, ftol=0.5, xtol=1,
                                      bounds=([0, 0], [2000, 800]))
        M0, t_const = nlfit
    except (RuntimeError, ValueError):
        t_const = 0
    return t_const
# Concatenate datasets in data into a single 5D array.
concat5D = np.concatenate([block[..., np.newaxis] for block in data],
axis=len(data[0].shape))
# And apply the curve fitting along the last dimension.
decay_map = np.apply_along_axis(fit, len(concat5D.shape) - 1, concat5D)
The code works fine, but takes forever (e.g., for dataset1.shape == (100,100,50,500)). I've read some other topics mentioning that apply_along_axis is very slow, so I'm guessing that's the culprit. Unfortunately, I don't really know what could be used as an alternative here (except maybe an explicit for loop?).
Does anyone have an idea of what I can do to avoid apply_along_axis and speed up curve_fit being called multiple times?
Upvotes: 2
Views: 1352
Reputation: 26050
In this particular case, where you're fitting a single exponential, you're likely better off taking the log of your data. Fitting then becomes linear, which is much faster than nonlinear least squares, and it can likely be vectorized, since it becomes essentially a linear algebra problem.
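To make that concrete, here is a minimal sketch of the vectorized log-linear fit. For y = M0*exp(-t/t_const), taking logs gives log(y) = log(M0) - t/t_const, a straight line in t, so every voxel can be fitted in a single least-squares call. The shapes and variable names below are made up for illustration (noise-free toy data, not the asker's datasets):

```python
import numpy as np

times = np.array([10.0, 20.0, 30.0])

# Fake 5D data: a (2, 3, 4, 5) grid of voxels, 3 time points on the last
# axis, every voxel decaying with the same (known) parameters.
M0_true, t_const_true = 80.0, 50.0
decay = M0_true * np.exp(-times / t_const_true)          # shape (3,)
data5d = np.broadcast_to(decay, (2, 3, 4, 5, 3)).copy()

# Flatten all voxel axes into rows: (n_voxels, n_times), then take logs.
y = np.log(data5d.reshape(-1, times.size))

# Design matrix for log(y) = a + b*t, with a = log(M0), b = -1/t_const.
A = np.column_stack([np.ones_like(times), times])
coef, *_ = np.linalg.lstsq(A, y.T, rcond=None)           # coef: (2, n_voxels)

M0_map = np.exp(coef[0]).reshape(2, 3, 4, 5)
t_const_map = (-1.0 / coef[1]).reshape(2, 3, 4, 5)
```

One lstsq call replaces all the per-voxel curve_fit calls. Note that on noisy data the log transform changes the error weighting (it downweights large values), so the result is not identical to a nonlinear fit of the raw data.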
(And of course, if you have an idea of how to improve least_squares, that might be appreciated by the scipy devs.)
Upvotes: 0
Reputation: 231395
So you are applying a fit operation 100*100*50*500 times, to a 1d array (of 3 values in the example, more in real life?)? apply_along_axis does iterate over all the dimensions of the input array except one. There's no compiling, and no way of doing this fit over multiple axes at once.
Without apply_along_axis, the easiest approach is to reshape the array into a 2d one, compressing (100,100,50,500) into a single 250,000,000-element dimension, iterating over that, and then reshaping the result.
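A sketch of that reshape-and-iterate pattern, with a trivial stand-in for the real fit function (the shapes here are toy values, not the question's):

```python
import numpy as np

def fit(vector):
    # Stand-in for the real per-vector curve fit, which returns a scalar
    # (t_const in the question).
    return vector.sum()

concat5d = np.ones((4, 5, 6, 7, 3))                  # toy 5D stack
flat = concat5d.reshape(-1, concat5d.shape[-1])      # (n_voxels, n_times)
res = np.array([fit(v) for v in flat])               # one scalar per voxel
decay_map = res.reshape(concat5d.shape[:-1])         # back to (4, 5, 6, 7)
```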
I was thinking that concatenating the datasets on the last axis might be slower than doing so on the first, but timings suggest otherwise. np.stack is a newer version of concatenate that makes it easy to add the new axis anywhere.
In [319]: x=np.ones((2,3,4,5),int)
In [320]: d=[x,x,x,x,x,x]
In [321]: np.stack(d,axis=0).shape # same as np.array(d)
Out[321]: (6, 2, 3, 4, 5)
In [322]: np.stack(d,axis=-1).shape
Out[322]: (2, 3, 4, 5, 6)
for a larger list (with a trivial sum function):
In [295]: d1=[x]*1000 # make a big list
In [296]: timeit np.apply_along_axis(sum,-1,np.stack(d1,-1)).shape
10 loops, best of 3: 39.7 ms per loop
In [297]: timeit np.apply_along_axis(sum,0,np.stack(d1,0)).shape
10 loops, best of 3: 39.2 ms per loop
An explicit loop over a reshaped array times about the same:
In [312]: %%timeit
.....: d2=np.stack(d1,-1)
.....: d2=d2.reshape(-1,1000)
.....: res=np.stack([sum(i) for i in d2],0).reshape(d1[0].shape)
.....:
10 loops, best of 3: 39.1 ms per loop
But a function like sum can work on the whole array at once, and do so much faster:
In [315]: timeit np.stack(d1,-1).sum(-1).shape
100 loops, best of 3: 3.52 ms per loop
So changing the stacking and iteration methods doesn't make much difference in speed. But changing the 'fit' so it can work over more than one dimension can be a big help. I don't know enough of optimize.curve_fit to know if that is possible.
====================
I just dug into the code for apply_along_axis. It basically constructs an index that looks like ind=(0,1,slice(None),2,1), does func(arr[ind]), and then increments the index, sort of like long arithmetic with carry. So it is just systematically stepping through all elements, while keeping one axis as a : slice.
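The stepping described above can be sketched roughly like this (a simplified model keeping the last axis as the slice, not the actual numpy implementation, which also handles arbitrary axes and non-scalar results):

```python
import numpy as np

def apply_last_axis(func, arr):
    # Iterate over every index combination of all axes but the last,
    # keeping the last axis as a full slice(None).
    out = np.empty(arr.shape[:-1])
    for ind in np.ndindex(arr.shape[:-1]):
        # e.g. ind == (0, 1) indexes arr[0, 1, :]
        out[ind] = func(arr[ind + (slice(None),)])
    return out

a = np.arange(24).reshape(2, 3, 4)
result = apply_last_axis(np.sum, a)
```

This makes it clear why apply_along_axis gives no speedup over a hand-written loop: it is the same per-element Python iteration, just generalized.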
Upvotes: 2