bbuzz31

Reputation: 140

Sympy: Drop terms without a specific variable

I'm trying to compute a large number of multivariate conditional densities (i.e., products of several multivariate probability density functions). I'm able to set up and expand the matrices properly, but now I would like to drop the terms that do not contain a specific variable; in the code below, for example, the terms that don't contain wg. With help from the posted answer, I was able to develop a hacky solution; improvements are welcome.

UPDATE: MWE

import sympy as sym
from IPython.display import display as disp
N  = 211
wg = sym.MatrixSymbol('w_g', N, 1)
wg_n = sym.MatrixSymbol('w_gn', N, 1)
Z_wg = sym.MatrixSymbol('Z_wg', N, N)

# pdf wg
pdf_wg = ((wg - wg_n).T * Z_wg.I * (wg - wg_n))
pdf_full = sym.expand(pdf_wg)
# pdf_full.collect(wg) # NotImplementedError: noncommutative scalars in MatMul are not supported.


# print (wg in pdf_full.atoms()) # False

# this gives what I want: display only the terms whose atoms include w_g (string match)
terms = pdf_full.as_terms()[0]
for term in terms:
    if 'w_g,' in str(term[0].atoms()):
        disp (term[0])

UPDATE 2: More Complex MWE

Here I'm trying to grab just the terms with b in them.

import sympy as sym
from IPython.display import display as disp, Math

mu   = sym.symbols('mu')               # mean non GIA SSH trend
N    = 211
vec1 = sym.MatrixSymbol('1', N, 1)
u    = sym.MatrixSymbol('u', N, 1)     
Pi   = sym.MatrixSymbol('Pi', N, N)    
b    = sym.MatrixSymbol('b', N, 1)

wg = sym.MatrixSymbol('w_g', N, 1)     
wm = sym.MatrixSymbol('w_m', N, 1)     
bhat = mu*vec1 + wg + wm + u # convenience

pdf  = sym.expand((b - bhat).T * Pi.I * (b-bhat))
terms      = pdf.as_terms()[0]
good_terms = [] 
for term in terms:
    if b.args[0] in term[0].atoms():
        good_terms.append(term[0])

print ('Good terms:'); disp(sym.Add(*good_terms))

UPDATE 4: Solved. For more complex expressions, adding doit() to the expand prevents a bunch of extra loops, e.g.:

pdf  = sym.expand((b - bhat).T * Pi.I * (b-bhat)).doit()

More information can be found in the comments to the various answers.

Thanks!


Upvotes: 0

Views: 190

Answers (2)

Oscar Benjamin

Reputation: 14500

You can get the terms that do not contain wg by substituting a zero matrix for it (ZeroMatrix is imported from sympy):

In [53]: pdf_full.subs(wg, ZeroMatrix(N, 1)).doit()
Out[53]: 
    T     -1     
w_gn ⋅Z_wg  ⋅w_gn

Then you can subtract those from pdf_full:

In [54]: pdf_full - pdf_full.subs(wg, ZeroMatrix(N, 1)).doit()
Out[54]: 
   T     -1         T     -1           T     -1    
w_g ⋅Z_wg  ⋅w_g -w_g ⋅Z_wg  ⋅w_gn -w_gn ⋅Z_wg  ⋅w_g
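
As a standalone script, the same idea might look roughly like the sketch below. The helper name split_by_symbol is made up for illustration and is not part of SymPy; it assumes that substituting a zero matrix and calling doit() leaves exactly the terms without the variable, as it does for the quadratic form in the question.

```python
import sympy as sym

N = 211
wg = sym.MatrixSymbol('w_g', N, 1)
wg_n = sym.MatrixSymbol('w_gn', N, 1)
Z_wg = sym.MatrixSymbol('Z_wg', N, N)

pdf_full = sym.expand((wg - wg_n).T * Z_wg.I * (wg - wg_n))


def split_by_symbol(expr, var):
    """Hypothetical helper: return (terms with var, terms without var)."""
    rows, cols = var.shape
    # Setting var to zero leaves only the terms that do not involve it.
    without_var = expr.subs(var, sym.ZeroMatrix(rows, cols)).doit()
    # Subtracting those from the full expression leaves the rest.
    with_var = (expr - without_var).doit()
    return with_var, without_var


with_wg, without_wg = split_by_symbol(pdf_full, wg)
print(with_wg)
print(without_wg)
```

Passing b instead of wg should work the same way for the more complex MWE in the question, since the zero matrix is sized from the variable's own shape.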

Upvotes: 1

JohanC

Reputation: 80399

You could extract the atoms of the expression and test whether the variable is among them:

from sympy import symbols

a, b, mug = symbols('a b mu_g')
expr1 = a * b + a * mug
expr2 = a * b
for expr in [expr1, expr2]:
    if mug in expr.atoms():
        print(expr, 'contains', mug)
    else:
        print(expr, 'does not contain', mug)

PS: An update for your new question. For a MatrixSymbol the symbol is stored as wg.args[0] (args[1] and args[2] are the dimensions):

import sympy as sym

N  = 211
wg = sym.MatrixSymbol('w_g', N, 1)
wg_n = sym.MatrixSymbol('w_gn', N, 1)
Z_wg = sym.MatrixSymbol('Z_wg', N, N)

pdf_wg = ((wg - wg_n).T * Z_wg.I * (wg - wg_n))
pdf_full = sym.expand(pdf_wg)

print (wg.args[0] in pdf_full.atoms()) # True

Note that the hacky solution in the question could go wrong when w_g is the last item in the printed atoms (so no comma follows it) or when another name ends in the same string.
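
One way to avoid string matching altogether is to test each summand's free_symbols, which for matrix expressions contains the MatrixSymbol objects themselves. This is only a sketch: it assumes the expanded expression is a sum whose .args are the individual terms, and the names kept and dropped are illustrative.

```python
import sympy as sym

N = 211
wg = sym.MatrixSymbol('w_g', N, 1)
wg_n = sym.MatrixSymbol('w_gn', N, 1)
Z_wg = sym.MatrixSymbol('Z_wg', N, N)

pdf_full = sym.expand((wg - wg_n).T * Z_wg.I * (wg - wg_n))

# Each summand of the expanded expression is one element of .args.
# Keep a summand only if the MatrixSymbol itself is among its free symbols.
kept = [term for term in pdf_full.args if wg in term.free_symbols]
dropped = [term for term in pdf_full.args if wg not in term.free_symbols]

# Rebuild a matrix sum from the terms that involve w_g.
result = sym.MatAdd(*kept)
print(result)
```

Because the test compares symbol objects rather than printed names, w_gn is not mistaken for w_g.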

Upvotes: 1
