Nico Schlömer

Reputation: 58871

SymPy lambdify with dot()

Take an undefined function that happens to be named dot and use it in an expression passed to lambdify:

import numpy
import sympy


class dot(sympy.Function):
    pass

x = sympy.Symbol('x')

a = sympy.Matrix([1, 0, 0])
f = sympy.lambdify(x, dot(a.T, x))

x = numpy.array([3, 2, 1])
print(f(x))

Surprise: This actually works!

Apparently, the name "dot" is somehow picked up and replaced by an implementation of the dot product. Does anyone know which one?

The result of the above is [3]. I would, however, like to get the scalar 3. (How) can I modify f() to achieve that?
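The [3] seems to come from the shapes involved rather than from dot itself. A minimal NumPy-only sketch, assuming the transposed SymPy matrix a.T reaches NumPy as the 2-D array [[1, 0, 0]] (shape (1, 3)): dotting that with a 1-D vector gives a length-1 array, while dotting two 1-D vectors gives a 0-d scalar. The variable names are illustrative only.

```python
import numpy as np

row = np.array([[1, 0, 0]])   # assumed stand-in for a.T: shape (1, 3)
vec = np.array([3, 2, 1])     # shape (3,)

# (1, 3) . (3,) -> shape (1,), i.e. the [3] seen above
res = np.dot(row, vec)
print(res, res.shape)         # [3] (1,)

# Flattening the row first makes it (3,) . (3,), which yields a 0-d scalar
flat = np.dot(row.ravel(), vec)
print(flat, np.ndim(flat))    # 3 0

# .item() turns a size-1 array into a plain Python scalar
print(res.item())             # 3
```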

Upvotes: 2

Views: 382

Answers (1)

Bakuriu

Reputation: 102029

I'm not a SymPy user, but the documentation for lambdify says:

If not specified differently by the user, SymPy functions are replaced as far as possible by either python-math, numpy (if available) or mpmath functions - exactly in this order. To change this behavior, the “modules” argument can be used. It accepts:

  • the strings “math”, “mpmath”, “numpy”, “numexpr”, “sympy”
  • any modules (e.g. math)
  • dictionaries that map names of sympy functions to arbitrary functions
  • lists that contain a mix of the arguments above, with higher priority given to entries appearing first.
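To make the "higher priority given to entries appearing first" rule concrete, here is a toy model of the lookup using collections.ChainMap, which likewise returns the value from the first mapping that contains a key. This only illustrates the documented rule; it is not SymPy's actual code, it ignores the string forms ("numpy", "math", ...), and resolve and fake_dot are hypothetical names.

```python
import math
from collections import ChainMap

import numpy


def resolve(name, *sources):
    """Toy model: return `name` from the first source that defines it.

    Each source is a module or a plain dict, two of the kinds of values
    the quoted documentation says `modules` accepts.
    """
    maps = [s if isinstance(s, dict) else vars(s) for s in sources]
    return ChainMap(*maps)[name]  # raises KeyError if no source has it


def fake_dot(u, v):  # hypothetical user override
    return "user dot"


# math has no `dot`, so the lookup falls through to numpy
print(resolve("dot", math, numpy) is numpy.dot)              # True

# a dict listed first wins, as with modules=[{...}, 'numpy']
print(resolve("dot", {"dot": fake_dot}, numpy) is fake_dot)  # True

# when both define a name, the earlier source wins
print(resolve("sqrt", math, numpy) is math.sqrt)             # True
```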

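So, according to the documentation, each name is looked up first in Python's standard-library math module (which is always available), then in numpy if it is installed, and finally in mpmath; the modules argument lets you change this. The math module has no dot, so with numpy installed the name dot resolves to numpy.dot, which is the implementation you are seeing.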
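You can check which dot was picked by inspecting the function that lambdify returns. A sketch, assuming a recent SymPy, where the lambdify documentation suggests inspect.getsource for viewing the generated code, and where the generated function's globals are the namespace lambdify built (an implementation detail, not a documented API, so it may differ between versions):

```python
import inspect

import numpy
import sympy


class dot(sympy.Function):
    pass


x = sympy.Symbol('x')
a = sympy.Matrix([1, 0, 0])
f = sympy.lambdify(x, dot(a.T, x))

# The generated source calls the plain name dot(...)
source = inspect.getsource(f)
print(source)

# That name is resolved in the namespace the code was executed in
# (implementation detail): with numpy installed it should be numpy.dot
print(f.__globals__['dot'] is numpy.dot)
```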

In your case, pass a modules value that includes a dictionary mapping the name dot to a function that returns the scalar you want.


An example of what I mean:

>>> import numpy as np
>>> import sympy
>>> class dot(sympy.Function): pass
... 
>>> x = sympy.Symbol('x')
>>> a = sympy.Matrix([1,0,0])
>>> f = sympy.lambdify(x, dot(a.T, x), modules=[{'dot': lambda x, y: np.dot(x, y)[0]}, 'numpy'])
>>> y = np.array([3,2,1])
>>> print(f(y))
3
>>> print(type(f(y)))
<class 'numpy.int64'>

As you can see, manipulating the modules argument lets you achieve what you want. My implementation here is naive (it always takes element [0] of the result), but you can generalize it like this:

>>> def my_dot(x, y):
...     res = np.dot(x, y)
...     if res.ndim == 1 and res.size == 1:
...         return res[0]
...     return res

This function checks whether the result of the normal np.dot is a one-element 1-D array; if so, it returns that element as a scalar, and otherwise it returns the same result as np.dot.
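A quick self-contained check of my_dot's behaviour on plain NumPy inputs; the arrays are illustrative assumptions standing in for what the lambdified function passes in:

```python
import numpy as np


def my_dot(x, y):
    """np.dot, but unwrap a one-element 1-D result into a scalar."""
    res = np.dot(x, y)
    if res.ndim == 1 and res.size == 1:
        return res[0]
    return res


v = np.array([3, 2, 1])

# (1, 3) . (3,) -> one-element array, unwrapped to a NumPy scalar
print(my_dot(np.array([[1, 0, 0]]), v))             # 3

# (2, 3) . (3,) -> two-element array, returned unchanged
print(my_dot(np.array([[1, 0, 0], [0, 1, 0]]), v))  # [3 2]

# (3,) . (3,) -> already a 0-d scalar, so nothing to unwrap
print(my_dot(np.array([1, 1, 1]), v))               # 6
```

It plugs in the same way as the lambda in the example: modules=[{'dot': my_dot}, 'numpy'].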

Upvotes: 2
