Reputation: 51
The ODE solvers in Python take the right-hand side (RHS) function as an argument. This function is supposed to have the signature `f(t, y, *args)`. In general, `y` can be a NumPy array, and `f` returns an array `ydot` with the same size as `y`. The syntax, however, demands that `f` return one thing and only one thing: the `ydot` array (or a float, if `y` is also a float). The ODE solver internally calls `f` as many times as it needs in order to reach convergence.
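To make that contract concrete, here is a minimal sketch (the decay equation and the name `rhs` are illustrative, not from the question). `solve_ivp` unpacks the `args` tuple into extra positional arguments on every call, expects a single array shaped like `y` back, and reports in `sol.nfev` how many times it evaluated the function, which is typically far more than the number of time points in `sol.t`.

```python
import numpy as np
from scipy.integrate import solve_ivp


def rhs(t, y, k):
    """Illustrative RHS with the signature f(t, y, *args)."""
    # Return exactly one array, with the same shape as y
    return -k * y


y0 = np.array([2.0, 4.0, 8.0])
# args=(0.5,) is unpacked into rhs(t, y, 0.5) on every call
sol = solve_ivp(rhs, (0, 10), y0, args=(0.5,))

# nfev counts every call to rhs, including intermediate Runge-Kutta stages
print(len(sol.t), sol.nfev)
```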
In my case, however, this `f` returns a tuple: not just `ydot`, but a few more arrays. I could move that part outside `f` and call it after the solver has finished, but that changes the behavior, because those parameters would no longer be updated at every time-step calculation (which they should be). So I need to keep that code inside `f`. But solvers like `scipy.integrate.solve_ivp` demand that `f` return only `ydot`, so I have no way to 'catch' the extra parameters and pass them back to `f` when the solver calls it again.
Below is code that gives the general idea (this snippet is just a stand-in for my actual code, which I cannot share, but the gist is the same: my RHS function returns a tuple):
```python
import numpy as np
from scipy.integrate import solve_ivp

# Placeholder update rules (the real ones are not shown)
def some_func_1(ydot, param):
    return param + ydot

def some_func_2(ydot, param):
    return param * ydot

# Define the RHS function
def f(t, y, param1, param2):
    # Calculate the derivative dy/dt
    ydot = np.sin(t) * y[0] + np.cos(t) * y[1]  # Example derivative equation
    # Calculate additional parameters
    param1 = some_func_1(ydot, param1)
    param2 = some_func_2(ydot, param2)
    # Return ydot and additional parameters as a tuple
    # (solve_ivp cannot use this return value)
    return ydot, param1, param2

# Define the time span and initial condition
t_span = (0, 5)
y0 = np.random.rand(2)  # Random initial condition
param1, param2 = 0.0, 1.0  # Placeholder initial values
# Solve the ODE using solve_ivp
sol = solve_ivp(f, t_span, y0, args=(param1, param2))
```
`some_func_1` and `some_func_2` are functions that take `param1` and `param2` and modify them.
I tried looking into the `events` argument of `solve_ivp`, but it did not work; I guess it is meant for something else.
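For reference, `events` does serve a different purpose: each event function `g(t, y)` is monitored for sign changes, and `solve_ivp` reports the times and states where `g(t, y) = 0`. It does not carry extra values out of the RHS. A minimal sketch, using an illustrative falling-body ODE that is not from the question:

```python
import numpy as np
from scipy.integrate import solve_ivp


def fall(t, y):
    # y[0] = height, y[1] = velocity, constant gravity
    return [y[1], -9.81]


def hit_ground(t, y):
    # Event function: crosses zero when the height reaches zero
    return y[0]


hit_ground.terminal = True   # stop integrating at the first zero crossing
hit_ground.direction = -1    # only trigger while the height is decreasing

sol = solve_ivp(fall, (0, 10), [10.0, 0.0], events=hit_ground)
print(sol.t_events[0])  # time at which the height reaches zero
```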
This seems like a problem that someone must have encountered before, right? One workaround I can see is to use global variables: instead of returning the extra parameters, I would store them in a global list or variable. But I think global variables are dangerous because they can introduce bugs, so I am looking for a better approach.
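One common alternative to globals is to keep the extra parameters as attributes of an object whose `__call__` method is the RHS, since `solve_ivp` accepts any callable. A minimal sketch, assuming hypothetical update rules in place of `some_func_1`/`some_func_2` (the real ones are not shown in the question); the class name `RHS` and the `history` attribute are illustrative:

```python
import numpy as np
from scipy.integrate import solve_ivp


class RHS:
    """Callable RHS that keeps its extra parameters as attributes (no globals)."""

    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        self.history = []  # (t, param1, param2) recorded after every evaluation

    def __call__(self, t, y):
        # The derivative can read the parameter values left by the previous call
        ydot = -self.param2 * y
        # Hypothetical updates standing in for some_func_1 / some_func_2
        self.param1 = self.param1 + np.abs(ydot).sum()  # e.g. an accumulated quantity
        self.param2 = min(self.param2 * 1.001, 1.0)     # e.g. a slowly drifting rate, capped at 1.0
        self.history.append((t, self.param1, self.param2))
        return ydot  # solve_ivp only ever sees ydot


rhs = RHS(param1=0.0, param2=0.5)
sol = solve_ivp(rhs, (0, 5), np.array([1.0, 2.0]))
print(rhs.param1, rhs.param2, len(rhs.history), sol.nfev)
```

Because the solver also calls the function at intermediate stages, the parameters advance once per evaluation rather than once per accepted time step, so the result can depend on how often the solver happens to call it.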
EDIT:
I just realised something: is it possible for the `args` of `solve_ivp` to take functions themselves? Then the call becomes:

```python
sol = solve_ivp(f, t_span, y0, args=(some_func_1(param1), some_func_2(param2)))
```

But even here, the problem of catching the values returned by `some_func_1` remains.
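`args` can hold any Python objects, including functions, but writing `some_func_1(param1)` inside the tuple calls the function once, before integration starts, and passes its result. To have the RHS call the functions on every evaluation, the function objects themselves can be passed, together with a mutable container for the values they return. A minimal sketch, where `update_1`, `update_2` and the `state` dict are hypothetical stand-ins:

```python
import numpy as np
from scipy.integrate import solve_ivp


def update_1(ydot, param):
    # Hypothetical stand-in for some_func_1
    return param + float(np.sum(ydot))


def update_2(ydot, param):
    # Hypothetical stand-in for some_func_2: largest |ydot| seen so far
    return max(param, float(np.max(np.abs(ydot))))


def f(t, y, func1, func2, state):
    ydot = -0.5 * y  # illustrative derivative
    # Call the functions received through args and store their results in the
    # mutable dict, so later calls (and the caller) see the updated values
    state["param1"] = func1(ydot, state["param1"])
    state["param2"] = func2(ydot, state["param2"])
    return ydot


state = {"param1": 0.0, "param2": 0.0}
# Pass the function objects themselves (no parentheses), not their results
sol = solve_ivp(f, (0, 5), np.array([1.0, 2.0]), args=(update_1, update_2, state))
print(state)
```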
Upvotes: 0
Views: 135
Reputation: 9046
Your ODE appears to be unaffected by the values of `param1` and `param2`. So I recommend removing them from the ODE and simply computing them afterward from the solution you get from `solve_ivp`.
```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, y):
    return np.sin(t) * y[0] + np.cos(t) * y[1]

def compute_params(t, y):
    # Placeholder: compute param1 and param2 from the solution arrays
    raise NotImplementedError

t_span = (0, 5)
y0 = np.random.default_rng(42).random(2)
sol = solve_ivp(f, t_span, y0)
param1, param2 = compute_params(sol.t, sol.y)
```
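If the extra quantities really are derived from the solution, they can be computed after the fact, vectorized over all output points. A minimal sketch, where the body of `compute_params` is a hypothetical placeholder for `some_func_1`/`some_func_2` and `t_eval` is added only to request a regular output grid:

```python
import numpy as np
from scipy.integrate import solve_ivp


def f(t, y):
    return np.sin(t) * y[0] + np.cos(t) * y[1]


def compute_params(t, y):
    # f uses only NumPy operations, so it also works column-wise:
    # t has shape (n,), y has shape (2, n), and ydot comes back with shape (n,)
    ydot = f(t, y)
    # Hypothetical post-processing standing in for some_func_1 / some_func_2
    param1 = np.cumsum(ydot)
    param2 = np.abs(ydot)
    return param1, param2


t_span = (0, 5)
y0 = np.random.default_rng(42).random(2)
t_eval = np.linspace(*t_span, 51)  # evaluate the solution on a regular grid
sol = solve_ivp(f, t_span, y0, t_eval=t_eval)
param1, param2 = compute_params(sol.t, sol.y)
```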
Upvotes: 0
Reputation: 231615
With care, it is possible to pass `args` objects to the function and modify them. The objects need to be mutable, such as arrays or lists.
For example, taking the `exponential_decay` example from the `solve_ivp` documentation, let's add a means of collecting all `t` values. I'll define `par` as a list and use `append` to modify it in place. I could have done something with an array, but this was the easiest thing to think of.
```
In [316]: def exponential_decay(t, y): return -0.5 * y
     ...: def f(t,y,par):
     ...:     val = exponential_decay(t,y)
     ...:     par.append(t)
     ...:     return val
     ...: par = []
     ...: sol = integrate.solve_ivp(f, [0, 10], [2, 4, 8], args=(par,))
In [317]: sol
Out[317]:
message: The solver successfully reached the end of the integration interval.
success: True
status: 0
t: [ 0.000e+00 1.149e-01 1.264e+00 3.061e+00 4.816e+00
6.574e+00 8.333e+00 1.000e+01]
y: [[ 2.000e+00 1.888e+00 ... 3.107e-02 1.351e-02]
[ 4.000e+00 3.777e+00 ... 6.214e-02 2.702e-02]
[ 8.000e+00 7.553e+00 ... 1.243e-01 5.403e-02]]
sol: None
t_events: None
y_events: None
nfev: 44
njev: 0
nlu: 0
In [318]: len(par)
Out[318]: 44
In [319]: np.array(par)
Out[319]:
array([ 0. , 0.02 , 0.02297531, 0.03446296, 0.09190123,
0.10211248, 0.11487653, 0.11487653, 0.3446296 , 0.45950614,
1.03388881, 1.13600129, 1.26364188, 1.26364188, 1.62303707,
1.80273466, 2.70122262, 2.86095382, 3.06061781, 3.06061781,
3.41171646, 3.58726578, 4.4650124 , 4.62105624, 4.81611105,
4.81611105, 5.16778045, 5.34361515, 6.22278865, 6.37908617,
6.57445806, 6.57445806, 6.92622442, 7.1021076 , 7.98152352,
8.13786412, 8.33328988, 8.33328988, 8.66663191, 8.83330292,
9.66665798, 9.81480999, 10. , 10. ])
```
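The same pattern can carry the question's extra parameters: pass a mutable object through `args`, update it in place inside the RHS, and read it after `solve_ivp` returns. A minimal sketch, using a hypothetical update rule in place of `some_func_1` and a 1-element NumPy array (the names `param` and `history` are illustrative). As the session above shows (44 recorded `t` values for 8 output times, some repeated), there is one record per evaluation, matching `sol.nfev`, not one per accepted step.

```python
import numpy as np
from scipy.integrate import solve_ivp


def exponential_decay(t, y):
    return -0.5 * y


def f(t, y, param, history):
    ydot = exponential_decay(t, y)
    # Hypothetical update standing in for some_func_1; writing into param[0]
    # modifies the caller's array in place instead of rebinding a local name
    param[0] += np.sum(ydot)
    history.append((t, param[0]))  # one record per evaluation
    return ydot


param = np.zeros(1)  # mutable 1-element array holding param1
history = []
sol = solve_ivp(f, [0, 10], [2, 4, 8], args=(param, history))

t_hist, param1_hist = np.array(history).T
print(len(history), sol.nfev)
```

If the parameters must change only once per accepted step, mapping these records back to `sol.t` is not one-to-one and needs extra care.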
Upvotes: 0