peakcipher

Reputation: 51

Handling multiple return values from the RHS function passed to an ODE solver

The ODE solvers in Python take the RHS of the ODE as an argument. This function is supposed to have the signature f(t, y, *args). In general, y can be a numpy array, and f returns an array ydot of the same size as y. The syntax, however, demands that f return one thing and one thing only: the ydot array (or a float if y is also a float). The ODE solver internally calls f as many times as it needs to reach convergence.
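A minimal sketch of the calling convention described above, using a made-up decay equation rather than the real problem: f receives t, the state array y and any extra positional arguments supplied through args, and returns an array shaped like y. The attribute sol.nfev reports how many times the solver called f, which is normally several times per accepted step.

```python
import numpy as np
from scipy.integrate import solve_ivp

# RHS with the signature solve_ivp expects: f(t, y, *args) -> array shaped like y
def rhs(t, y, k):
    return -k * y  # example equation dy/dt = -k*y (not from the question)

y0 = np.array([1.0, 2.0])
sol = solve_ivp(rhs, (0.0, 5.0), y0, args=(0.5,))  # k = 0.5 is passed via args

print(sol.y.shape[0])        # 2: one row per component of y
print(sol.nfev, len(sol.t))  # RHS evaluations vs. accepted output times
```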

In my case, however, f returns a tuple: not just ydot, but a few more arrays. I could move this part outside f and call it after the solver has converged, but that changes the behavior, because those parameters would no longer be updated at every time-step calculation (which they should be). So I need to keep that code inside f. But solvers like scipy.integrate.solve_ivp require f to return only ydot, so I have no way to 'catch' the extra parameters and pass them back to f when the solver calls it again.

Below is code that gives the general idea (note that this snippet is just a stand-in for my actual code, which I cannot share, but the gist is the same: my RHS function returns a tuple):

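import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical placeholders for the real update functions (not shown here)
def some_func_1(ydot, param):
    return param + ydot

def some_func_2(ydot, param):
    return param - ydot

# Define the RHS function
def f(t, y, param1, param2):
    # Calculate the derivative dy/dt
    ydot = np.sin(t) * y[0] + np.cos(t) * y[1]  # Example derivative equation

    # Calculate additional parameters
    param1 = some_func_1(ydot, param1)
    param2 = some_func_2(ydot, param2)

    # Return ydot and additional parameters as a tuple
    # (this is the problem: solve_ivp expects f to return only ydot)
    return ydot, param1, param2

# Define the time span, initial condition and initial parameter values
t_span = (0, 5)
y0 = np.random.rand(2)  # Random initial condition
param1, param2 = 0.0, 1.0  # hypothetical initial values

# Solve the ODE using solve_ivp (does not work, because f returns a tuple)
sol = solve_ivp(f, t_span, y0, args=(param1, param2))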

some_func_1 and some_func_2 are functions that take param1 and param2 (together with ydot) and modify them.

I tried looking into the events argument of solve_ivp, but it did not work. I guess it is for something else.
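For reference, events is meant for locating times where a user-supplied function g(t, y) crosses zero (optionally stopping the integration there), not for returning extra values from f. A small sketch with a made-up equation dy/dt = -y, whose exact solution from y(0) = 1 reaches 0.5 at t = ln 2:

```python
import numpy as np
from scipy.integrate import solve_ivp

def decay(t, y):
    return -y  # example equation dy/dt = -y, y(0) = 1

def half_reached(t, y):
    return y[0] - 0.5  # the event fires where this changes sign

half_reached.terminal = True  # stop the integration at the first crossing

sol = solve_ivp(decay, (0.0, 5.0), [1.0], events=half_reached)
print(sol.t_events[0])  # close to ln 2, i.e. about 0.693
```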

This seems like a problem that someone must have encountered before. One workaround I can see is to use global variables: instead of returning the extra parameters, I store them in a global list or variable. But global variables are risky because they can introduce bugs, so I am looking for a cleaner approach.
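One common alternative to globals, sketched here under the assumption that the extra parameters only need to be read back after the solve, is a callable object that keeps them as instance state. The class name RHS, the example derivative and the bodies of some_func_1/some_func_2 below are hypothetical stand-ins, not the question's real code.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical stand-ins for the question's update functions
def some_func_1(ydot, param):
    return param + float(np.sum(ydot))

def some_func_2(ydot, param):
    return param - float(np.sum(ydot))

class RHS:
    """Callable RHS that keeps the extra parameters as instance state."""

    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        self.history = []  # (t, param1, param2) recorded on every call

    def __call__(self, t, y):
        ydot = -0.5 * y  # example derivative, same shape as y
        # Update the extra parameters on every call, as in the question
        self.param1 = some_func_1(ydot, self.param1)
        self.param2 = some_func_2(ydot, self.param2)
        self.history.append((t, self.param1, self.param2))
        return ydot  # solve_ivp only ever sees ydot

rhs = RHS(param1=0.0, param2=1.0)
sol = solve_ivp(rhs, (0.0, 5.0), np.array([1.0, 2.0]))
print(sol.success, rhs.param1, rhs.param2)
```

Keep in mind that the state is updated on every RHS evaluation, including intermediate stages and rejected steps, so the stored values depend on the solver's internal step choices. If the parameters also feed back into ydot, the RHS is no longer a pure function of (t, y), which adaptive solvers assume.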

EDIT:

I just realised something: is it possible for the args of solve_ivp to be functions themselves? Then the call becomes:

sol = solve_ivp(f, t_span, y0, args=(some_func_1(param1), some_func_2(param2)))

But even then, the problem of catching the values returned by some_func_1 remains.
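args can hold any objects, including functions, but a call such as some_func_1(param1) written inside the args tuple is evaluated once, before solve_ivp starts, and f only ever receives the result. A small sketch with a hypothetical update function that counts its own calls shows the difference between passing a call result and passing the function object itself:

```python
import numpy as np
from scipy.integrate import solve_ivp

counter = {"calls": 0}

def update(param):
    counter["calls"] += 1  # record each call of this hypothetical update function
    return param + 1.0

def f_value(t, y, p):
    return -p * y  # p is the number that was placed in args

def f_func(t, y, func):
    p = func(0.5)  # the function itself was placed in args, so it runs on every call
    return -p * y

# 1) update(0.5) is evaluated once, right here; f_value only sees the number 1.5
solve_ivp(f_value, (0.0, 1.0), [1.0], args=(update(0.5),))
print(counter["calls"])  # 1

# 2) Passing the function object lets f call it on every evaluation
counter["calls"] = 0
sol = solve_ivp(f_func, (0.0, 1.0), [1.0], args=(update,))
print(counter["calls"] == sol.nfev)  # True
```

In the second case the return value of update is still discarded after each call unless f stores it somewhere, for example in a mutable object also passed through args.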

Upvotes: 0

Views: 135

Answers (2)

jared

Reputation: 9046

Your ODE appears to be unaffected by the values of param1 and param2. So, I recommend removing them from the ODE and simply computing them afterward using the solution you get from solve_ivp.

import numpy as np
from scipy.integrate import solve_ivp

def f(t, y):
    return np.sin(t) * y[0] + np.cos(t) * y[1]

def compute_params(t, y):
    # Placeholder: compute param1 and param2 from the stored solution (t, y)
    raise NotImplementedError

t_span = (0, 5)
y0 = np.random.default_rng(42).random(2)
sol = solve_ivp(f, t_span, y0)

param1, param2 = compute_params(sol.t, sol.y)
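A sketch of what compute_params could look like, assuming the parameter updates only need to happen once per stored output time rather than on every internal RHS evaluation. The example equation, the bodies of some_func_1/some_func_2 and the initial parameter values are hypothetical placeholders; the loop simply replays the updates over the stored solution.

```python
import numpy as np
from scipy.integrate import solve_ivp

def f(t, y):
    return -0.5 * y  # example equation (placeholder for the real ODE)

# Hypothetical stand-ins for the question's update functions
def some_func_1(ydot, param):
    return param + float(np.sum(ydot))

def some_func_2(ydot, param):
    return 0.9 * param + float(np.max(ydot))

def compute_params(t, y, param1=0.0, param2=1.0):
    """Replay the parameter updates once per stored time point."""
    p1 = np.empty(t.size)
    p2 = np.empty(t.size)
    for i in range(t.size):
        ydot = f(t[i], y[:, i])  # recompute the derivative at the stored state
        param1 = some_func_1(ydot, param1)
        param2 = some_func_2(ydot, param2)
        p1[i], p2[i] = param1, param2
    return p1, p2

t_span = (0, 5)
y0 = np.random.default_rng(42).random(2)
sol = solve_ivp(f, t_span, y0)
param1, param2 = compute_params(sol.t, sol.y)
```

Passing t_eval to solve_ivp fixes the output times, and with them the points at which the updates are replayed.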

Upvotes: 0

hpaulj

Reputation: 231615

With care, it is possible to pass objects to the function through args and modify them. The objects need to be mutable, such as arrays or lists.

For example, taking the exponential_decay example from the solve_ivp documentation, let's add a way to collect all t values. I'll define par as a list and use append to modify it in place. I could have done something similar with an array, but this was the easiest thing to think of.

In [316]: def exponential_decay(t, y): return -0.5 * y
     ...: def f(t,y,par):
     ...:     val = exponential_decay(t,y)
     ...:     par.append(t)
     ...:     return val
     ...: par = []
     ...: sol = integrate.solve_ivp(f, [0, 10], [2, 4, 8], args=(par,))
In [317]: sol
Out[317]: 
  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  1.149e-01  1.264e+00  3.061e+00  4.816e+00
             6.574e+00  8.333e+00  1.000e+01]
        y: [[ 2.000e+00  1.888e+00 ...  3.107e-02  1.351e-02]
            [ 4.000e+00  3.777e+00 ...  6.214e-02  2.702e-02]
            [ 8.000e+00  7.553e+00 ...  1.243e-01  5.403e-02]]
      sol: None
 t_events: None
 y_events: None
     nfev: 44
     njev: 0
      nlu: 0
In [318]: len(par)
Out[318]: 44
In [319]: np.array(par)
Out[319]: 
array([ 0.        ,  0.02      ,  0.02297531,  0.03446296,  0.09190123,
        0.10211248,  0.11487653,  0.11487653,  0.3446296 ,  0.45950614,
        1.03388881,  1.13600129,  1.26364188,  1.26364188,  1.62303707,
        1.80273466,  2.70122262,  2.86095382,  3.06061781,  3.06061781,
        3.41171646,  3.58726578,  4.4650124 ,  4.62105624,  4.81611105,
        4.81611105,  5.16778045,  5.34361515,  6.22278865,  6.37908617,
        6.57445806,  6.57445806,  6.92622442,  7.1021076 ,  7.98152352,
        8.13786412,  8.33328988,  8.33328988,  8.66663191,  8.83330292,
        9.66665798,  9.81480999, 10.        , 10.        ])
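Note that par gets one entry per RHS evaluation (44, matching nfev), including intermediate stages, so several times appear twice and most do not appear in sol.t. If the extra values only matter at the accepted steps, one option is to log (t, value) pairs and keep the last entry logged at each time in sol.t. This sketch assumes the default RK45 method, which (as the duplicated times above suggest) evaluates f at the end of every accepted step; the summed derivative is only a stand-in for a real extra quantity.

```python
import numpy as np
from scipy import integrate

def exponential_decay(t, y):
    return -0.5 * y

def f(t, y, log):
    val = exponential_decay(t, y)
    log.append((t, val.sum()))  # stand-in "extra" quantity logged on every call
    return val

log = []
sol = integrate.solve_ivp(f, [0, 10], [2, 4, 8], args=(log,))  # default method RK45

ts = np.array([entry[0] for entry in log])
extra = np.array([entry[1] for entry in log])

# For each accepted time, keep the last logged evaluation at (or within rounding of) it
idx = [np.flatnonzero(np.isclose(ts, tk))[-1] for tk in sol.t]
extra_at_steps = extra[idx]
print(len(log), len(sol.t), extra_at_steps.shape)
```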

Upvotes: 0
