Reputation: 58871
Take an undefined function that happens to be named dot, and make it part of lambdify:
import numpy
import sympy
class dot(sympy.Function):
    pass
x = sympy.Symbol('x')
a = sympy.Matrix([1, 0, 0])
f = sympy.lambdify(x, dot(a.T, x))
x = numpy.array([3, 2, 1])
print(f(x))
Surprise: This actually works! Apparently, the string "dot" is somehow extracted and replaced by an implementation of the dot product. Does anyone know which one?
The result of the above is [3]. I would, however, like to get the scalar 3. (How) can I modify f() to achieve that?
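For reference, getting [3] rather than 3 is consistent with NumPy's shape rules: a.T is a 1x3 row, and, assuming lambdify passes it to numpy.dot as a 2-D array of shape (1, 3) (an assumption about the generated code, not something shown above), its product with a length-3 1-D vector has shape (1,). A NumPy-only sketch:

```python
import numpy as np

# Stand-in for a.T (assumption: lambdify passes it as a 2-D array of shape (1, 3))
row = np.array([[1, 0, 0]])
# The 1-D input vector from the question, shape (3,)
v = np.array([3, 2, 1])

res = np.dot(row, v)   # (1, 3) . (3,) -> result of shape (1,)
print(res, res.shape)  # [3] (1,)

# Unwrapping the single element gives a scalar
print(res[0])          # 3 (a NumPy integer scalar)
print(res.item())      # 3 (a plain Python int)
```

Indexing with [0] or calling .item() is enough to unwrap a one-element result like this.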
Upvotes: 2
Views: 382
Reputation: 102029
I'm not a sympy user, but the documentation for lambdify says:
If not specified differently by the user, SymPy functions are replaced as far as possible by either python-math, numpy (if available) or mpmath functions - exactly in this order. To change this behavior, the “modules” argument can be used. It accepts:
- the strings “math”, “mpmath”, “numpy”, “numexpr”, “sympy”
- any modules (e.g. math)
- dictionaries that map names of sympy functions to arbitrary functions
- lists that contain a mix of the arguments above, with higher priority given to entries appearing first.
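To see the "higher priority given to entries appearing first" rule in action, here is a small sketch using a hypothetical my_sin that just returns 42.0 (it exists only to make the lookup visible). Whichever of the dictionary and "math" comes first in the list decides what sin means inside the generated function:

```python
import math

import sympy

x = sympy.Symbol('x')


def my_sin(v):
    """Hypothetical replacement for sin, used only to make the lookup order visible."""
    return 42.0


# Dictionary listed first: its 'sin' takes priority over math.sin
f_dict_first = sympy.lambdify(x, sympy.sin(x), modules=[{'sin': my_sin}, 'math'])
# 'math' listed first: math.sin takes priority over the dictionary entry
f_math_first = sympy.lambdify(x, sympy.sin(x), modules=['math', {'sin': my_sin}])

print(f_dict_first(0.5))                   # 42.0
print(f_math_first(0.5) == math.sin(0.5))  # True
```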
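So, according to the documentation, names are looked up first in Python's math module (which is part of the standard library, so it is always available), then in numpy if it is installed, and then in mpmath; the documentation then describes how to modify this behaviour. Since math has no dot function, the dot in your expression ends up as numpy's dot.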
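One way to check which dot is picked up is to look inside the generated function. The sketch below assumes a recent SymPy, where inspect.getsource works on lambdified functions and the function's globals are the namespace lambdify built; modules='numpy' is passed explicitly so the lookup does not depend on which other packages are installed:

```python
import inspect

import numpy
import sympy


class dot(sympy.Function):
    """Undefined SymPy function; only its name matters to lambdify."""
    pass


x = sympy.Symbol('x')
a = sympy.Matrix([1, 0, 0])
f = sympy.lambdify(x, dot(a.T, x), modules='numpy')

# The generated source simply calls a function named `dot`
print(inspect.getsource(f))

# The generated code runs in a namespace filled from numpy,
# so the name `dot` there should be bound to numpy.dot
print(f.__globals__['dot'] is numpy.dot)
```

If the last line prints True, the name dot in the generated code is bound to numpy.dot, which accounts for the [3] in the question.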
In your case, just provide a modules value that is a dictionary mapping the name dot to a function that returns a scalar, as you want.
An example of what I mean:
>>> import numpy as np
>>> import sympy
>>> class dot(sympy.Function): pass
...
>>> x = sympy.Symbol('x')
>>> a = sympy.Matrix([1,0,0])
>>> f = sympy.lambdify(x, dot(a.T, x), modules=[{'dot': lambda x, y: np.dot(x, y)[0]}, 'numpy'])
>>> y = np.array([3,2,1])
>>> print(f(y))
3
>>> print(type(f(y)))
<class 'numpy.int64'>
As you can see, by manipulating the modules argument you can achieve what you want. My implementation here is very naive, but you can generalize it like this:
>>> def my_dot(x, y):
...     res = np.dot(x, y)
...     if res.ndim == 1 and res.size == 1:
...         return res[0]
...     return res
...
This function checks whether the result of the plain np.dot is a one-element 1-D array; if so, it returns that element as a scalar, otherwise it returns the same result as np.dot.
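Putting the pieces together, my_dot can be passed in the modules list the same way as the lambda above. A sketch, assuming the same dot class and matrix as in the question:

```python
import numpy as np
import sympy


class dot(sympy.Function):
    pass


def my_dot(x, y):
    """Like np.dot, but unwrap a one-element 1-D result into a scalar."""
    res = np.dot(x, y)
    if res.ndim == 1 and res.size == 1:
        return res[0]
    return res


x = sympy.Symbol('x')
a = sympy.Matrix([1, 0, 0])

# The dictionary comes first in the list, so its 'dot' wins over numpy's dot
f = sympy.lambdify(x, dot(a.T, x), modules=[{'dot': my_dot}, 'numpy'])

print(f(np.array([3, 2, 1])))         # 3
print(my_dot(np.eye(2), np.ones(2)))  # [1. 1.] (two elements, returned unchanged)
```

Because my_dot only unwraps results with exactly one element, genuine vector or matrix products still come back as arrays.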
Upvotes: 2