Erotemic
Erotemic

Reputation: 5248

How to decompose a 2x2 affine matrix with sympy?

I'm attempting to show the decomposition of an affine matrix with sympy as shown in the following stackexchange post:

https://math.stackexchange.com/questions/612006/decomposing-an-affine-transformation

I've set up two matrices, A_params and A_matrix, where the former is the matrix constructed from its underlying parameters and the latter represents the raw matrix values.

import sympy
import itertools as it
import ubelt as ub
# Assumptions shared by the decomposition parameters.
domain = {'real': True}

# Decomposition parameters: rotation angle, per-axis scales, and shear factor.
# The scales are additionally declared nonzero.
theta = sympy.symbols('theta', **domain)
sx, sy = sympy.symbols('sx, sy', nonzero=True, **domain)
m = sympy.symbols('m', **domain)

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix([  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]])


# A_params: the 2x2 matrix written in terms of its parameters (scale is
# applied first, then shear, then rotation).
A_params = sympy.simplify((R @ H @ S))
# A_matrix: the same 2x2 matrix written in terms of its four raw entries.
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix([[a11, a12], [a21, a22]])


# Show both forms, each with its label placed to the left of the pretty matrix.
print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))
A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

From what I understand I should simply be able to set these two matrices to be equal and then solve for the parameters of interest. However, I'm getting unexpected results.

First, if I just try to solve for "sx", I get no result.

# Option 1: hand the solver a single matrix-valued equation.
mat_equation = sympy.Eq(A_matrix, A_params)
soln_sx = sympy.solve(mat_equation, sx)
print(f'soln_sx = {soln_sx!r}')

# Option 2: hand the solver one scalar equation per matrix entry, pairing
# the entries of both matrices in row-major order.
lhs_iter = it.chain.from_iterable(A_matrix.tolist())
rhs_iter = it.chain.from_iterable(A_params.tolist())
equations = list(it.starmap(sympy.Eq, zip(lhs_iter, rhs_iter)))
soln_sx = sympy.solve(equations, sx)
print(f'soln_sx = {soln_sx!r}')
soln_sx = []
soln_sx = []

But if I try to solve for all variables simultaneously, I get a result, but it does not agree with what I would expect:

# Solve for every decomposition parameter jointly, then report the first
# solution with each component simplified.
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for)
for symbol, raw_sol in zip(solve_for, solutions[0]):
    sol = sympy.simplify(raw_sol)
    print(f'sol({symbol!r}) = {sol!r}')
sol(sx) = -(a11**2 + a11*sqrt(a11**2 + a21**2) + a21**2)/(a11 + sqrt(a11**2 + a21**2))
sol(theta) = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
sol(sy) = (-8*a11**6*a22 + 8*a11**5*a12*a21 - 8*a11**5*a22*sqrt(a11**2 + a21**2) + 8*a11**4*a12*a21*sqrt(a11**2 + a21**2) - 12*a11**4*a21**2*a22 + 12*a11**3*a12*a21**3 - 8*a11**3*a21**2*a22*sqrt(a11**2 + a21**2) + 8*a11**2*a12*a21**3*sqrt(a11**2 + a21**2) - 4*a11**2*a21**4*a22 + 4*a11*a12*a21**5 - a11*a21**4*a22*sqrt(a11**2 + a21**2) + a12*a21**5*sqrt(a11**2 + a21**2))/(8*a11**6 + 8*a11**5*sqrt(a11**2 + a21**2) + 16*a11**4*a21**2 + 12*a11**3*a21**2*sqrt(a11**2 + a21**2) + 9*a11**2*a21**4 + 4*a11*a21**4*sqrt(a11**2 + a21**2) + a21**6)
sol(m) = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)

After having a hard time getting the above to work, I wanted to see if I could at least verify the solution from stackexchange. So I coded that up symbolically:

# This is the guided solution by Stéphane Laurent
# Recover the x-scale and rotation angle from the first column of the matrix.
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

# Intended to recover the product m * sy.
# NOTE: the sin and cos factors on this line are swapped; the corrected
# version in the answer below uses a12 * cos + a22 * sin.
recon_msy = a12 * recon_sin_t + a22 * recon_cos_t

# Two candidate formulas for sy: the first divides by sin(theta), the second
# by cos(theta). The second is only selected when sin(theta) == 0.
condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
condition1 = sympy.simplify(sympy.Not(condition2))
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t

recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))

# Shear factor: divide the product m * sy by the recovered sy.
recon_m = recon_msy / recon_sy

recon_S = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy]])

recon_H = sympy.Matrix([  # shear
    [1, recon_m],
    [0, 1]])

recon_R = sympy.Matrix([  # rotation
    [sympy.cos(recon_theta), -sympy.sin(recon_theta)],
    [sympy.sin(recon_theta),  sympy.cos(recon_theta)]])

# Recombine the components
A_recon = sympy.simplify((recon_R @ recon_H @ recon_S))
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

That results in something quite like what I would expect, but it doesn't seem to simplify all the way down to the point where it can be programmatically validated.

A_recon = ⎡     ⎧                                       a₂₁            ⎤
          ⎢     ⎪            a₁₂              for ──────────────── ≠ 0 ⎥
          ⎢     ⎪                                    _____________     ⎥
          ⎢     ⎪                                   ╱    2      2      ⎥
          ⎢a₁₁  ⎨                                 ╲╱  a₁₁  + a₂₁       ⎥
          ⎢     ⎪                                                      ⎥
          ⎢     ⎪a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂                           ⎥
          ⎢     ⎪───────────────────────────         otherwise         ⎥
          ⎢     ⎩            a₁₁                                       ⎥
          ⎢                                                            ⎥
          ⎢     ⎧-a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁            a₂₁           ⎥
          ⎢     ⎪────────────────────────────  for ──────────────── ≠ 0⎥
          ⎢     ⎪            a₂₁                      _____________    ⎥
          ⎢a₂₁  ⎨                                    ╱    2      2     ⎥
          ⎢     ⎪                                  ╲╱  a₁₁  + a₂₁      ⎥
          ⎢     ⎪                                                      ⎥
          ⎣     ⎩            a₂₂                      otherwise        ⎦

My thought is that the conditional is messing it up, so I tried just using the two cases separately:

# Case 1: always use the sy formula that divides by sin(theta).
recon_sy2 = sy_cond1
recon_m2 = recon_msy / recon_sy2

recon_S2 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy2]])

recon_H2 = sympy.Matrix([  # shear
    [1, recon_m2],
    [0, 1]])


# Case 2: always use the sy formula that divides by cos(theta).
recon_sy3 = sy_cond2
recon_m3 = recon_msy / recon_sy3

recon_S3 = sympy.Matrix([  # scale
    [recon_sx,  0],
    [ 0, recon_sy3]])

recon_H3 = sympy.Matrix([  # shear
    [1, recon_m3],
    [0, 1]])


# Recombine the components
# (both cases reuse the rotation matrix recon_R from the previous snippet)
A_recon2 = sympy.simplify((recon_R @ recon_H2 @ recon_S2))
A_recon3 = sympy.simplify((recon_R @ recon_H3 @ recon_S3))
print('')
print(ub.hzcat(['A_recon2 = ', sympy.pretty(A_recon2)]))
print('')
print(ub.hzcat(['A_recon3 = ', sympy.pretty(A_recon3)]))
A_recon2 = ⎡a₁₁              a₁₂             ⎤
           ⎢                                 ⎥
           ⎢     -a₁₁⋅a₁₂ + a₁₁⋅a₂₂ + a₁₂⋅a₂₁⎥
           ⎢a₂₁  ────────────────────────────⎥
           ⎣                 a₂₁             ⎦

A_recon3 = ⎡     a₁₁⋅a₂₂ + a₁₂⋅a₂₁ - a₂₁⋅a₂₂⎤
           ⎢a₁₁  ───────────────────────────⎥
           ⎢                 a₁₁            ⎥
           ⎢                                ⎥
           ⎣a₂₁              a₂₂            ⎦

But that doesn't seem to allow any further simplification.

I'm not quite seeing how a22 and a12 pop out of the top and bottom equations respectively. They should if this decomposition is correct, so these results are making me worried that it is not.

So my questions are twofold:

  1. Can any sympy gurus help me get the basic solution for the decomposition working?

  2. Is the decomposition in the reference SE post wrong? Or am I not including a constraint that would allow simplification? If so how would I do that in sympy?

Update

I was able to get a bit further by using sympy.radsimp on the equations from sympy.solve when all variables are solved for jointly (still not sure why it won't solve for sx by itself).

# Solve for all four parameters jointly; dict=True makes each solution a
# mapping from symbol to expression.
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)
# minimal=True, quick=True, cubics=False, quartics=False, quintics=False, check=False)
for sol in solutions:
    for sym, symsol0 in sol.items():
        # Chain several simplifiers; radsimp is applied both before and
        # after the trig/general simplification passes.
        symsol = sympy.radsimp(symsol0)
        symsol = sympy.trigsimp(symsol)
        symsol = sympy.simplify(symsol)
        symsol = sympy.radsimp(symsol)
        # Report the symbol, the repr of its solution, and a pretty rendering.
        print('\n=====')
        print('sym = {!r}'.format(sym))
        print('symsol  = {!r}'.format(symsol))
        print('--')
        sympy.pretty_print(symsol, wrap_line=False)
        print('--')
        print('=====\n')
=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = 2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
      ⎛         _____________⎞
      ⎜        ╱    2      2 ⎟
      ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
2⋅atan⎜──────────────────────⎟
      ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

But the solution for sx is closer to what I want (although it's a negative root, which I suppose is technically correct, but I was under the impression sympy only handled principal roots).

Main questions are still open though. (although I'm more confident the original SE post is correct).

And it looks like it is saying that "m" has the determinant in the denominator, which is interesting (and the numerator, a11*a12 + a21*a22, is the dot product of the columns).

Update2

I'm starting to think that there is some error in sympy or in the SE post. I started doing numerical checks, and they give errors that I don't think are reconcilable (i.e. same after rotation).

Numerical checking code is

# Numerically spot-check the reconstructions against a random ground truth.
# Requires A_params, A_matrix, A_recon and A_solved_recon to already be
# defined by the other snippets.
import numpy as np

# Raw matrix entries in row-major order, matching the flattened A_params.
elements = [a11, a12, a21, a22]

# Draw a random value for every decomposition parameter and build the
# corresponding concrete matrix.
params = [sx, theta, sy, m]
params_rand = {p: np.random.rand() for p in params}
A_params_rand = A_params.subs(params_rand)

# Map each raw entry symbol to the concrete value it takes in that matrix.
matrix_rand = dict(zip(elements, ub.flatten(A_params_rand.tolist())))
A_matrix_rand = A_matrix.subs(matrix_rand)
A_solved_rand = A_solved_recon.subs(matrix_rand)
A_recon_rand = A_recon.subs(matrix_rand)

mat1 = np.array(A_matrix_rand.tolist()).astype(float)
mat2 = np.array(A_params_rand.tolist()).astype(float)
mat3 = np.array(A_recon_rand.tolist()).astype(float)
mat4 = np.array(A_solved_rand.tolist()).astype(float)

# Sanity check: substituting the entries must reproduce the random matrix.
assert np.allclose(mat1, mat2)

# Residuals of the manual and the solver-based reconstructions; both should
# be (numerically) zero if the respective decomposition is correct.
print(mat2 - mat3)
print(mat2 - mat4)

Random values seem to always produce some error at a22 in the matrix, so I think the sympy reconstruction of the matrix from the manually inputted decomposition is wrong, or the decomposition itself is wrong. Any help would be very valuable.

Upvotes: 1

Views: 412

Answers (1)

Erotemic
Erotemic

Reputation: 5248

After discussion with a colleague, it turns out I made a simple error in the code. I swapped sin and cos terms. Fixing this results in the correct reconstruction of the matrix when using @Stéphane Laurent's decomposition:

import sympy
import ubelt as ub

# Assumptions shared by the decomposition parameters.
domain = {'real': True}

# Decomposition parameters: rotation angle, per-axis scales, and shear factor.
theta = sympy.symbols('theta', **domain)
sx, sy = sympy.symbols('sx, sy', **domain)
m = sympy.symbols('m', **domain)
params = [sx, theta, sy, m]

S = sympy.Matrix([  # scale
    [sx,  0],
    [ 0, sy]])

H = sympy.Matrix([  # shear
    [1, m],
    [0, 1]])

R = sympy.Matrix((  # rotation
    [sympy.cos(theta), -sympy.sin(theta)],
    [sympy.sin(theta),  sympy.cos(theta)]))

# A_params: the matrix in terms of its parameters (scale, then shear, then
# rotation). A_matrix: the same matrix in terms of its four raw entries.
A_params = sympy.simplify((R @ H @ S))
a11, a12, a21, a22 = sympy.symbols(
    'a11, a12, a21, a22', real=True)
A_matrix = sympy.Matrix(((a11, a12), (a21, a22)))

print(ub.hzcat(['A_matrix = ', sympy.pretty(A_matrix)]))
print(ub.hzcat(['A_params = ', sympy.pretty(A_params)]))


# This is the guided solution by Stéphane Laurent
# Recover the x-scale and rotation angle from the first column of the matrix.
recon_sx = sympy.sqrt(a11 * a11 + a21 * a21)
recon_theta = sympy.atan2(a21, a11)
recon_sin_t = sympy.sin(recon_theta)
recon_cos_t = sympy.cos(recon_theta)

# Product m * sy: substituting the second column of A_params gives
# a12 * cos(theta) + a22 * sin(theta) == m * sy.
recon_msy = a12 * recon_cos_t + a22 * recon_sin_t


# condition2 = sympy.simplify(sympy.Eq(recon_sin_t, 0))
# condition1 = sympy.simplify(sympy.Not(condition2))
# Two formulas for sy; each branch divides by whichever of sin(theta) /
# cos(theta) has the larger magnitude.
condition1 = sympy.Gt(recon_sin_t ** 2, recon_cos_t ** 2)
condition2 = sympy.Le(recon_sin_t ** 2, recon_cos_t ** 2)
sy_cond1 = (recon_msy * recon_cos_t - a12) / recon_sin_t
sy_cond2 = (a22 - recon_msy * recon_sin_t) / recon_cos_t
recon_sy = sympy.Piecewise((sy_cond1, condition1), (sy_cond2, condition2))
# Shear factor: divide the product m * sy by the recovered sy.
recon_m = sympy.simplify(recon_msy / recon_sy)


# Substitute the decomposition into the "A_params" to reconstruct "A_matrix"
recon_symbols = {
    sx: recon_sx,
    theta: recon_theta,
    m: recon_m,
    sy: recon_sy
}

# Print a simplified form of every recovered parameter. The simplified value
# is only displayed; recon_symbols itself is left unchanged.
for sym, symval in recon_symbols.items():
    # symval = sympy.radsimp(symval)
    symval = sympy.trigsimp(symval)
    symval = sympy.simplify(symval)
    # radsimp is skipped when the simplified result is still a Piecewise.
    if not isinstance(symval, sympy.Piecewise):
        symval = sympy.radsimp(symval)
    print('\n=====')
    print('sym = {!r}'.format(sym))
    print('symval  = {!r}'.format(symval))
    print('--')
    sympy.pretty_print(symval)
    print('=====\n')

# Plug the recovered parameters back into the parametric matrix.
A_recon = A_params.subs(recon_symbols)
A_recon = sympy.simplify(A_recon)
print(ub.hzcat(['A_recon = ', sympy.pretty(A_recon)]))

Output of reconstruction with Laurent's explicitly defined decomposition:

A_matrix = ⎡a₁₁  a₁₂⎤
           ⎢        ⎥
           ⎣a₂₁  a₂₂⎦
A_params = ⎡sx⋅cos(θ)  sy⋅(m⋅cos(θ) - sin(θ))⎤
           ⎢                                 ⎥
           ⎣sx⋅sin(θ)  sy⋅(m⋅sin(θ) + cos(θ))⎦

=====
sym = sx
symval  = sqrt(a11**2 + a21**2)
--
   _____________
  ╱    2      2
╲╱  a₁₁  + a₂₁
=====


=====
sym = theta
symval  = atan2(a21, a11)
--
atan2(a₂₁, a₁₁)
=====


=====
sym = m
symval  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
=====


=====
sym = sy
symval  = (a11*a22*sqrt(a11**2 + a21**2) - a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
           _____________              _____________
          ╱    2      2              ╱    2      2
a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   - a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁
───────────────────────────────────────────────────
                       2      2
                    a₁₁  + a₂₁
=====

A_recon = ⎡a₁₁  a₁₂⎤
          ⎢        ⎥
          ⎣a₂₁  a₂₂⎦

I was also able to get the solver to produce a solution that reconstructed "A_matrix" correctly, although I had to jump through some hoops, and the decomposition takes a different (somewhat strange) form. But it does produce the right answer:

# Let sympy solve the matrix equation for all four parameters at once, then
# verify the result by substituting it back into the parametric matrix.
mat_equation = sympy.Eq(A_matrix, A_params)
solve_for = (sx, theta, sy, m)
solutions = sympy.solve(mat_equation, *solve_for, dict=True)

# Simplified solution per symbol. If the solver returns several solutions,
# entries from later ones overwrite those from earlier ones.
solved = {}
for sol in solutions:
    for sym, symsol0 in sol.items():
        # Chain several simplifiers; radsimp is applied both before and
        # after the trig/general simplification passes.
        symsol = sympy.radsimp(symsol0)
        symsol = sympy.trigsimp(symsol)
        symsol = sympy.simplify(symsol)
        symsol = sympy.radsimp(symsol)
        print('\n=====')
        print('sym = {!r}'.format(sym))
        print('symsol  = {!r}'.format(symsol))
        print('--')
        sympy.pretty_print(symsol, wrap_line=False)
        solved[sym] = symsol
        print('--')
        print('=====\n')

# NOTE: the solved shear m is (a11*a12 + a21*a22) / det(A_matrix), i.e. the
# dot product of the matrix columns over the determinant.

A_solved_recon = sympy.simplify(A_params.subs(solved))

print(ub.hzcat(['A_solved_recon = ', sympy.pretty(A_solved_recon)]))

Although I haven't worked out all of the details, it does seem that this sympy-computed decomposition is correct:

=====
sym = sx
symsol  = -sqrt(a11**2 + a21**2)
--
    _____________
   ╱    2      2 
-╲╱  a₁₁  + a₂₁  
--
=====


=====
sym = theta
symsol  = -2*atan((a11 + sqrt(a11**2 + a21**2))/a21)
--
       ⎛         _____________⎞
       ⎜        ╱    2      2 ⎟
       ⎜a₁₁ + ╲╱  a₁₁  + a₂₁  ⎟
-2⋅atan⎜──────────────────────⎟
       ⎝         a₂₁          ⎠
--
=====


=====
sym = m
symsol  = (a11*a12 + a21*a22)/(a11*a22 - a12*a21)
--
a₁₁⋅a₁₂ + a₂₁⋅a₂₂
─────────────────
a₁₁⋅a₂₂ - a₁₂⋅a₂₁
--
=====


=====
sym = sy
symsol  = (-a11*a22*sqrt(a11**2 + a21**2) + a12*a21*sqrt(a11**2 + a21**2))/(a11**2 + a21**2)
--
             _____________              _____________
            ╱    2      2              ╱    2      2 
- a₁₁⋅a₂₂⋅╲╱  a₁₁  + a₂₁   + a₁₂⋅a₂₁⋅╲╱  a₁₁  + a₂₁  
─────────────────────────────────────────────────────
                        2      2                     
                     a₁₁  + a₂₁                      
--
=====

A_solved_recon = ⎡a₁₁  a₁₂⎤
                 ⎢        ⎥
                 ⎣a₂₁  a₂₂⎦

Upvotes: 0

Related Questions