pythonjam

Reputation: 13

python sympy - how to process expressions in order to protect them from specific evaluations

SymPy evaluates expressions automatically by default, which is problematic in scenarios where that automatic evaluation hurts numerical stability. I need a way to control what gets evaluated in order to preserve that stability.

The only official mechanism I'm aware of is the UnevaluatedExpr class, but it doesn't fit my purpose. Users of my code are not supposed to be burdened with numerical stability considerations: they simply want to enter an expression, and the code needs to do all the rest. Making them analyze the numerical stability of their own expressions is not an option; it needs to happen automatically.

First I tried to gain control of sympify() by monkeypatching it, since it seems to be the main culprit behind most of the calls that lead to unwanted evaluation, but I only got as far as catching all the calls, without being able to control them the way I wanted. I ran into so many walls there that I wouldn't even know where to start.

Modifying SymPy itself is, as you can probably imagine, not an option either, as I can't require users to apply exotic patches to their local SymPy installations.

Next I discovered that it's possible to write

with evaluate(False):
    doSomeStuffToExpression(expr)

This seems to violently shove evaluate=False down SymPy's throat no matter what. However, that means it deactivates all evaluation wholesale and does not allow any fine-grained control.
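To illustrate how blanket this switch is, here is a minimal sketch, assuming a SymPy version recent enough to export `evaluate` from the top-level `sympy` package. Inside the block, simplifications that have nothing to do with `exp` are skipped as well, which is exactly the lack of fine control described above.

```python
import sympy as sp
from sympy import evaluate

x = sp.Symbol('x')

with evaluate(False):
    # Everything constructed here stays exactly as written:
    # no collecting of like terms, no distributing a number over a sum,
    # and no numeric evaluation of exp() at a Float argument.
    collected = x + x
    distributed = 2*(x + 1)
    numeric_exp = sp.exp(sp.Float(2.0))

# Outside the block the same constructions evaluate as usual.
print(collected, '|', x + x)
print(distributed, '|', 2*(x + 1))
print(numeric_exp, '|', sp.exp(sp.Float(2.0)))
```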

Specifically, I want to deactivate evaluation when there is an Add inside a sympy.exp.

So the third attempt was to modify the expression tree: develop a method that takes the expression, traverses it and automatically wraps args with UnevaluatedExpr where needed (remember: I can't burden the user with doing that manually).
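One way to sketch such a traversal is shown below. It assumes that "protect" means wrapping every Add found anywhere inside an exp() argument, and `protect_exp_adds` is a hypothetical helper name. Rather than assigning to the private `_args` attribute, it rebuilds each node with `expr.func(*new_args)` and leaves the input untouched. SymPy expressions are meant to be immutable (they cache their hash, and `subs` is memoized), so in-place mutation may leave stale state behind. Nodes outside exp are rebuilt normally and may still evaluate, e.g. `(x + 2.*x)/4.` collapses to `0.75*x`.

```python
import sympy as sp
from sympy import UnevaluatedExpr


def protect_exp_adds(expr, inside_exp=False):
    """Hypothetical helper: return a rebuilt copy of `expr` in which every
    Add occurring (at any depth) inside an exp() argument is wrapped in
    UnevaluatedExpr. The input expression is not modified."""
    # Atoms (Symbol, Float, Integer, ...) have no args: nothing to rebuild.
    if not expr.args:
        return expr
    # Already protected: keep it as it is.
    if isinstance(expr, UnevaluatedExpr):
        return expr
    # Everything below an exp node counts as "inside exp".
    child_inside = inside_exp or isinstance(expr, sp.exp)
    new_args = [protect_exp_adds(arg, child_inside) for arg in expr.args]
    # Rebuild through the public constructor instead of touching _args;
    # this normal construction may evaluate the node.
    rebuilt = expr.func(*new_args)
    if inside_exp and isinstance(rebuilt, sp.Add):
        return UnevaluatedExpr(rebuilt)
    return rebuilt


x, z = sp.symbols('x z')
expr = sp.sympify('(x + 2.*x)/4. + exp((x+32.)/6.)', evaluate=False)

protected = protect_exp_adds(expr)
substituted = protected.subs(x, z)
print(protected)
print(substituted)
```

Because `subs()` rebuilds the tree through the same public constructors, an UnevaluatedExpr placed this way should survive the substitution, much like the manually wrapped expression does. `doit()` removes the wrappers again when the protection is no longer wanted.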

So I wrote the following code to test the new approach:

from sympy.core.expr import UnevaluatedExpr
from sympy.core.symbol import Symbol
import sympy as sp
from sympy.core.numbers import Float

x, z = sp.symbols('x z')

#expr = (x + 2.*x)/4. + sp.exp((x+sp.UnevaluatedExpr(32.))/6.)
expr = sp.sympify('(x + 2.*x)/4. + exp((x+32.)/6.)', evaluate=False)

expr_ = expr.subs(x, z)

print(expr)
print(expr_)
print('///////////\n')

def prep(expr, exp_depth = 0):
    
    # once we are inside UnevaluatedExpr, we need to continue to traverse
    # down to the Symbol and also wrap it with UnevaluatedExpr
    if isinstance(expr, UnevaluatedExpr): 
        for arg in expr.args:  
            newargs = []
            for arg_inside in arg.args:      
                if isinstance(arg_inside, Symbol) or isinstance(arg_inside, Float):
                    newargs.append(UnevaluatedExpr(arg_inside))
                else:
                    newargs.append(arg_inside)
                    
            arg._args = tuple(newargs)
            for arg_inside in arg.args:       
                prep(arg_inside, exp_depth = exp_depth + 1)        
        return
    
    original_args = expr.args
    # if args empty
    if not original_args: return
    
    # check if we just entered exp
    is_exp = (expr.func == sp.exp)

    print('\n-----')
    print('expression\t\t-->', expr)
    print('func || args\t\t-->', expr.func, ' || ', original_args)
    print('is it exp right now?\t-->', is_exp)
    print('inside exp?\t-->', exp_depth > 0)
    
    # if we just received exp or if we are inside exp
    if is_exp or exp_depth > 0:
        newargs = []
        for arg in original_args:
            if isinstance(arg, sp.Add):
                newargs.append(UnevaluatedExpr(arg))
            else:
                newargs.append(arg)
        
        expr._args= tuple(newargs)
          
        for arg in expr.args:
            prep(arg, exp_depth = exp_depth + 1)
    else:
        for arg in original_args: prep(arg, exp_depth)

prep(expr)

print('///////////\n')
print(expr)

substituted = expr.subs(x, z)
print("substitution after prep still does not work:\n", substituted)

wewantthis = expr.subs(x, UnevaluatedExpr(z))
print("we want:\n", wewantthis)

print('///////////\n')

However, the output was disappointing: subs() triggers the dreaded evaluation again, despite the args being wrapped in UnevaluatedExpr where needed (or at least where I understood wrapping to be needed).

For some reason subs() completely ignores my changes.

So the question is: is there any hope for this last approach (maybe I still missed something when traversing the tree)? And if there isn't, how else should I achieve the goal of disabling evaluation of a specific Symbol when an Add is encountered inside a sympy.exp (the exponential function)?

PS:

I should probably also mention that, for reasons that puzzle me, the following works (but, as mentioned, it's a manual solution that I don't want):

expr = (x + 2.*x)/4. + sp.exp((x+sp.UnevaluatedExpr(32.))/6.)
expr_ = expr.subs(x, z)
print(expr)
print(expr_)

Here we successfully prevented the evaluation of the Add inside sp.exp.

Output:

0.75*x + exp(0.166666666666667*(x + 32.0))
0.75*z + exp(0.166666666666667*(z + 32.0))

Edit 0:

  1. The solution should permit the use of floats. For example, some of the values may describe physical properties measured more precisely than an integer can express. I need to be able to allow those.

  2. Replacing Floats with Symbols is also problematic, as it substantially complicates handling the expressions and using them later.

Upvotes: 1

Views: 762

Answers (1)

Oscar Benjamin

Reputation: 14500

I'm not sure but I think that the problem you are having is to do with automatic distribution of a Number over an Add which is controlled by the distribute context manager:

In [326]: e1 = 2*(x + 1)

In [327]: e1
Out[327]: 2⋅x + 2

In [328]: from sympy.core.parameters import distribute

In [329]: with distribute(False):
     ...:     e2 = 2*(x + 1)
     ...:

In [330]: e2
Out[330]: 2⋅(x + 1)

The automatic distribution behaviour is something that would be good to change in sympy. It's just not easy to change because it is such a low-level operation and it has been this way for a long time (it would break a lot of people's code).
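Applied to the expression from the question, a minimal sketch would run the substitution itself inside the context manager. It assumes a SymPy version where `distribute` lives in `sympy.core.parameters`, as in the import above. The setting is global while the block is active, so anything else constructed inside it is left undistributed too.

```python
import sympy as sp
from sympy.core.parameters import distribute

x, z = sp.symbols('x z')
expr = sp.sympify('(x + 2.*x)/4. + exp((x+32.)/6.)', evaluate=False)

# Only the automatic "Number * Add -> expanded Add" step is switched off;
# other evaluation (collecting like terms, folding numbers) still happens.
with distribute(False):
    substituted = expr.subs(x, z)

print(substituted)
```

With distribution off, the Float coefficient stays outside the sum, so `exp` has no separate numeric term to pull out as a factor.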

Other parts of the evaluation that you see are specific to the fact that you are using floats and would not happen for Rational or for a symbol e.g.:

In [337]: exp(2*(x + 1))
Out[337]:
 2⋅x + 2
ℯ

In [338]: exp(2.0*(x + 1))
Out[338]:
                  2.0⋅x
7.38905609893065⋅ℯ

In [339]: exp(y*(x + 1))
Out[339]:
 y⋅(x + 1)
ℯ

You could convert floats to rationals with nsimplify to avoid that, e.g.:

In [340]: parse_expr('exp(2.0*(x + 1))', evaluate=False)
Out[340]:
 2.0⋅(x + 1)
ℯ

In [341]: parse_expr('exp(2.0*(x + 1))', evaluate=False).subs(x, z)
Out[341]:
                  2.0⋅z
7.38905609893065⋅ℯ

In [342]: nsimplify(parse_expr('exp(2.0*(x + 1))', evaluate=False))
Out[342]:
 2⋅x + 2
ℯ

In [343]: nsimplify(parse_expr('exp(2.0*(x + 1))', evaluate=False)).subs(x, z)
Out[343]:
 2⋅z + 2
ℯ

Another possibility is to use symbols and delay substitution of any values until numerical evaluation. This is the way to get the most accurate result from evalf:

In [344]: exp(y*(z + 1)).evalf(subs={y:1, z:2})                                                                                   
Out[344]: 20.0855369231877

Upvotes: 2
