user7468395

Reputation: 1369

How to find inverse of fast sigmoid function using sympy?

I would like to solve the following equation (inspired by: Fast sigmoid algorithm) for the variable x:

0 = lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) - y
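Read as y = f(x), the right-hand side is the plain fast sigmoid x / (1 + abs(x)), whose range is (-1, 1), rescaled onto (lower, upper). Here is a minimal sketch of that forward mapping for reference. The name `scaled_fast_sigmoid` is hypothetical and not from the linked algorithm.

```python
def scaled_fast_sigmoid(x: float, lower: float, upper: float) -> float:
    """Hypothetical helper: fast sigmoid x / (1 + |x|) mapped onto (lower, upper)."""
    fast = x / (1 + abs(x))  # plain fast sigmoid, range (-1, 1)
    return lower + (upper - lower) * (0.5 + 0.5 * fast)
```

Solving the equation for x means inverting this function for fixed lower and upper.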

If I use SymPy for this, I get an error:

from sympy.solvers import solve
from sympy import Symbol
x = Symbol('x', real=True)
y = Symbol('y', real=True)
lower = Symbol('lower', real=True)
upper = Symbol('upper', real=True)

solve(lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) -y, x)

error:

  File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in nfloat
    return type(expr)([nfloat(a, n, exponent) for a in expr])
  File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in <listcomp>
    return type(expr)([nfloat(a, n, exponent) for a in expr])
  File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in nfloat
    return type(expr)([nfloat(a, n, exponent) for a in expr])
TypeError: __new__() missing 1 required positional argument: 'cond'

How could I solve this equation with sympy?

(Or, if somebody is able to solve the equation for x manually: what would the inverse of the function look like?)

Upvotes: 2

Views: 1108

Answers (2)

Mike Amy

Reputation: 391

btw, I hand-derived a Python function for the inverse of the plain fast sigmoid x / (1 + abs(x)):

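from math import copysign

def inverse_fast_sigmoid(y):
    """Invert the plain fast sigmoid y = x / (1 + abs(x)); y must be in (-1, 1)."""
    if not -1.0 < y < 1.0:
        raise ValueError("y must lie strictly between -1 and 1")
    # |x| = 1 / (1 - |y|) - 1, and x has the same sign as y
    return copysign(1.0 / (1.0 - abs(y)) - 1.0, y)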
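Why the copysign formula works: s = x / (1 + |x|) always has the same sign as x, so the absolute value can be resolved by cases. Both cases collapse into one expression:

```latex
s = \frac{x}{1+|x|}
\;\Longrightarrow\;
\begin{cases}
x \ge 0: & s = \dfrac{x}{1+x} \;\Rightarrow\; x = \dfrac{s}{1-s} \\[1ex]
x < 0:   & s = \dfrac{x}{1-x} \;\Rightarrow\; x = \dfrac{s}{1+s}
\end{cases}
\;\Longrightarrow\;
x = \frac{s}{1-|s|} = \operatorname{sign}(s)\left(\frac{1}{1-|s|} - 1\right)
```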

Maybe you can adjust it for your version.
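One possible adjustment for the question's rescaled version is sketched below. It first undoes the affine rescaling, s = 2(y - lower)/(upper - lower) - 1, and then applies the plain inverse s / (1 - |s|), which equals the copysign expression. The function name and the ValueError checks are my own assumptions.

```python
def inverse_scaled_fast_sigmoid(y: float, lower: float, upper: float) -> float:
    """Hypothetical helper: solve
    y = lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) for x."""
    if lower == upper:
        raise ValueError("lower and upper must differ")
    # Undo the rescaling: map y from (lower, upper) back to s in (-1, 1).
    s = 2.0 * (y - lower) / (upper - lower) - 1.0
    if not -1.0 < s < 1.0:
        raise ValueError("y must lie strictly between lower and upper")
    # Plain fast-sigmoid inverse; same as copysign(1 / (1 - |s|) - 1, s).
    return s / (1.0 - abs(s))
```

For example, with lower = 0 and upper = 1, y = 0.75 gives s = 0.5 and therefore x = 1.0.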

Upvotes: 1

Oscar Benjamin

Reputation: 14540

This seems to be a bug in SymPy version 1.4. On master I don't get the exception and instead I get:

In [2]: solve(lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) -y, x)                                                                                                     
Out[2]: 
⎡⎧0.5⋅lower + 0.5⋅upper - y      0.5⋅(lower + upper - 2.0⋅y)      ⎧-0.5⋅lower - 0.5⋅upper + y      0.5⋅(-lower - upper + 2.0⋅y)    ⎤
⎢⎪─────────────────────────  for ─────────────────────────── < 0  ⎪──────────────────────────  for ──────────────────────────── ≥ 0⎥
⎢⎨        lower - y                       lower - y             , ⎨        upper - y                        upper - y              ⎥
⎢⎪                                                                ⎪                                                                ⎥
⎣⎩           nan                          otherwise               ⎩           nan                           otherwise              ⎦

This returns two piecewise solutions corresponding to the case of negative and positive x (I think).
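The "negative and positive x" reading can be checked numerically by transcribing the two entries of the output above into plain Python. In each entry the condition is the same expression as the value. For a y between lower and upper, exactly one entry should be non-nan: the first gives the negative root and the second the non-negative root. The helper name is hypothetical.

```python
import math

def master_output_pieces(y: float, lower: float, upper: float):
    """Hypothetical helper: evaluate both Piecewise entries of the solve() output.

    Each entry is nan unless its own condition holds.
    """
    first = (0.5 * lower + 0.5 * upper - y) / (lower - y)
    first = first if first < 0 else math.nan    # condition: value < 0
    second = (-0.5 * lower - 0.5 * upper + y) / (upper - y)
    second = second if second >= 0 else math.nan  # condition: value >= 0
    return first, second
```

With lower = 0 and upper = 1, y = 0.25 yields (-1.0, nan) and y = 0.75 yields (nan, 1.0). These match x = -1 and x = 1 in the forward equation.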

I'm not happy with the result above though. I think the proper result should be something like this:

In [46]: eqn = lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) - y                                                                                                       

In [47]: eqn = piecewise_fold(eqn.rewrite(Piecewise))                                                                                                                             

In [48]: eqn                                                                                                                                                                      
Out[48]: 
⎧                             ⎛0.5⋅x      ⎞           
⎪lower - y + (-lower + upper)⋅⎜───── + 0.5⎟  for x ≥ 0
⎪                             ⎝x + 1      ⎠           
⎨                                                     
⎪                             ⎛0.5⋅x      ⎞           
⎪lower - y + (-lower + upper)⋅⎜───── + 0.5⎟  otherwise
⎩                             ⎝1 - x      ⎠           

In [49]: sx1, = solve(eqn.args[0][0], x)                                                                                                                                          

In [50]: sx2, = solve(eqn.args[1][0], x)                                                                                                                                          

In [51]: cx1 = eqn.args[0][1].subs(x, sx1)                                                                                                                                        

In [52]: sol = Piecewise((sx1, cx1), (sx2, True))                                                                                                                                 

In [53]: sol                                                                                                                                                                      
Out[53]: 
⎧-0.5⋅lower - 0.5⋅upper + y      -0.5⋅lower - 0.5⋅upper + y    
⎪──────────────────────────  for ────────────────────────── ≥ 0
⎪        upper - y                       upper - y             
⎨                                                              
⎪0.5⋅lower + 0.5⋅upper - y                                     
⎪─────────────────────────               otherwise             
⎩        lower - y 
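Outside SymPy, the final `sol` can be transcribed directly into plain Python. The sketch below assumes lower != upper and y strictly between them, since a denominator vanishes at y == lower or y == upper. The function name is hypothetical.

```python
def inverse_from_piecewise(y: float, lower: float, upper: float) -> float:
    """Hypothetical helper: plain-Python transcription of the Piecewise `sol`."""
    # First piece: root of the x >= 0 branch, kept only if it is non-negative.
    x_pos = (-0.5 * lower - 0.5 * upper + y) / (upper - y)
    if x_pos >= 0:
        return x_pos
    # Otherwise: root of the x < 0 branch.
    return (0.5 * lower + 0.5 * upper - y) / (lower - y)
```

Because the condition is checked on the candidate root itself, this also handles upper < lower, where the mapping is decreasing.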

Upvotes: 3
