Frank-Rene Schäfer

Reputation: 3352

sympy - symbolic sum over symbolic number of elements

What is the most appropriate way to express the following in SymPy:

A sum over samples 'x[i]' with 'i' going from concrete 0 to symbolic 'N'. 'x[i]' itself shall be symbolic, i.e. always appear as variable.

The goal is to use these expressions in a system of linear equations.

Example (trivial least squares approx.):

Consider a set of samples (x[i], y[i]) that are supposed to lie on the line 'y = m*x + a', so the estimated line is determined by 'm' and 'a'. The error between the samples and the estimated line may be given by

 error(m, a) = sum((m * x[i] + a - y[i]) ** 2, start_i=0, end_i=N)

Now, finding the zeros of the derivatives 'd/dm error(m,a)' and 'd/da error(m,a)' yields the values of 'm' and 'a' that minimize the error. How can I find this solution with SymPy?

Upvotes: 9

Views: 7155

Answers (2)

Frank-Rene Schäfer

Reputation: 3352

The actual solution is to use IndexedBase symbols for the samples, an integer symbol for the summation index, and a Sum instance set up accordingly, i.e.

from sympy import symbols, IndexedBase, Sum, expand

x = IndexedBase('x')
y = IndexedBase('y')
n, m, a = symbols('n m a')
i = symbols('i', integer=True)

expr = Sum((m * x[i] + a - y[i]) ** 2, (i, 1, n)) 

From there, you can manipulate the expression as you would on paper, for example:

print(expand(expr))
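
To get from the sum to actual values for m and a, one possible route is sketched below. It is an illustration, not part of the original answer: the names S_x, S_y, S_xx, S_xy and S_yy are hypothetical placeholder symbols standing for the individual sums over i = 1..n, and the expanded error is written out by hand. The linear system is solved in terms of the placeholders, and the Sum objects are substituted back afterwards.

```python
from sympy import IndexedBase, Sum, symbols, diff, solve

x = IndexedBase('x')
y = IndexedBase('y')
m, a = symbols('m a')
i, n = symbols('i n', integer=True)

# Hypothetical placeholders for the sums over i = 1..n:
# S_x = sum x[i], S_y = sum y[i], S_xx = sum x[i]**2,
# S_xy = sum x[i]*y[i], S_yy = sum y[i]**2
S_x, S_y, S_xx, S_xy, S_yy = symbols('S_x S_y S_xx S_xy S_yy')

# sum((m*x[i] + a - y[i])**2) expanded term by term; the constant a**2
# summed over n samples gives n*a**2.
error = m**2*S_xx + 2*m*a*S_x - 2*m*S_xy + n*a**2 - 2*a*S_y + S_yy

# Normal equations: both partial derivatives must vanish.
sol = solve([diff(error, m), diff(error, a)], [m, a], dict=True)[0]

# Substitute the actual sums back in place of the placeholders
# (S_yy drops out when differentiating, so it needs no replacement).
sums = {
    S_x: Sum(x[i], (i, 1, n)),
    S_y: Sum(y[i], (i, 1, n)),
    S_xx: Sum(x[i]**2, (i, 1, n)),
    S_xy: Sum(x[i]*y[i], (i, 1, n)),
}
m_hat = sol[m].subs(sums)
a_hat = sol[a].subs(sums)

print(m_hat)
print(a_hat)
```

The solution is equivalent to m = (n*S_xy - S_x*S_y) / (n*S_xx - S_x**2) and a = (S_xx*S_y - S_x*S_xy) / (n*S_xx - S_x**2), although SymPy may print the terms in a different order.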

Upvotes: 1

Uriel

Reputation: 16184

Given your later question, I assume you have already figured most of this out, but for clarity's sake: the samples can be treated as functions of the index, so the notation is x(i). In current SymPy a plain Symbol is not callable, so x and y have to be created as Function objects.

The summation itself can be written with the summation function or the Sum constructor. I prefer summation, since it evaluates the sum where it can: a summand that does not depend on the index, as in summation(a, (i, 0, n)), comes back as a*(n + 1).

>>> from sympy import *
>>> m, a, i, n = symbols('m a i n')
>>> x, y = symbols('x y', cls=Function)
>>> err = summation((m * x(i) + a - y(i)) ** 2, (i, 0, n))
>>> pprint(err)
  n
 ___
 ╲
  ╲                      2
  ╱   (a + m⋅x(i) - y(i))
 ╱
 ‾‾‾
i = 0

Once summation has been given the summand and the (index, lower bound, upper bound) tuple, you can work with the resulting sum, for example by differentiating it:

>>> diff(err, m)
Sum(2*(a + m*x(i) - y(i))*x(i), (i, 0, n))
>>> diff(err, a)
Sum(2*a + 2*m*x(i) - 2*y(i), (i, 0, n))
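
As a rough check of these derivatives, here is a small sketch with a concrete upper bound, so that doit() can write the sums out term by term. The three sample points are made up for illustration (they lie exactly on y = 2*x + 1), and x, y are assumed to be Function objects.

```python
from sympy import Function, Sum, symbols, diff, solve

m, a = symbols('m a')
i = symbols('i', integer=True)
x, y = symbols('x y', cls=Function)

# Three samples, i = 0, 1, 2; Sum stays unevaluated until doit() is called.
err = Sum((m * x(i) + a - y(i)) ** 2, (i, 0, 2))

# Differentiate, then expand the sums into explicit terms x(0), x(1), ...
grad = [diff(err, m).doit(), diff(err, a).doit()]

# Hypothetical sample data lying on the line y = 2*x + 1.
data = {x(0): 0, x(1): 1, x(2): 2,
        y(0): 1, y(1): 3, y(2): 5}

# Both derivatives must vanish at the minimum.
sol = solve([g.subs(data) for g in grad], [m, a], dict=True)[0]
print(sol)
```

Because the made-up points sit exactly on a line, solving the two equations recovers that line's slope m = 2 and offset a = 1.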

Upvotes: 9
