I am trying to sample from a distribution defined on 4x3 matrices using the NUTS sampler. Here is my code:
import pymc as pm
import pytensor.tensor as pt
import pytensor.tensor.linalg as ptl
import numpy as np
normal=pm.MatrixNormal.dist(mu=np.eye(4,3),rowcov=np.eye(4),colcov=np.eye(3),shape=(4,3))
def logp(value):
    z=ptl.matrix_dot(value.T,value)
    w,v = ptl.eigh(z)
    w_invsqrt=1/pt.sqrt(w)
    inv_sqrt=ptl.matrix_dot(v,pt.diag(w_invsqrt),v.T)
    QX=ptl.matrix_dot(value, inv_sqrt)
    return -pt.trace(QX)**2+pm.logp(normal,QX)
with pm.Model() as model:
    Q = pm.DensityDist("Q" , logp=logp, ndim_supp=2,shape=(4,3))
    step = pm.NUTS()
    trace = pm.sample(draws=100, tune=100, chains=1,step=step, nuts_sampler='pymc', discard_tuned_samples=False)
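For reference, the logp above maps X to X (X^T X)^{-1/2}, the orthonormal (polar) factor of X, which is only defined when X has full column rank. A minimal NumPy-only sketch of the same eigh-based steps (the helper name `polar_factor` is made up for illustration) makes it easy to check the arithmetic outside PyTensor:

```python
import numpy as np

def polar_factor(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return x (x^T x)^{-1/2}, mirroring the steps in logp."""
    z = x.T @ x                                      # 3x3 Gram matrix X^T X
    w, v = np.linalg.eigh(z)                         # eigenvalues w, eigenvectors in columns of v
    inv_sqrt = v @ np.diag(1.0 / np.sqrt(w)) @ v.T   # (X^T X)^{-1/2}
    return x @ inv_sqrt

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))       # a random 4x3 matrix has full column rank almost surely
q = polar_factor(x)
print(np.allclose(q.T @ q, np.eye(3)))   # q has orthonormal columns
```

This only reproduces the math; it says nothing about how PyMC chooses the starting value.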
I have pymc version 5.13.1 and pytensor version 2.20.0.
I get the following error:
/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/nlinalg.py:181: FutureWarning: pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead.
warnings.warn(
Traceback (most recent call last):
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
self.vm()
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 515, in rval
r = p(n, [x[0] for x in i], o)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/slinalg.py", line 298, in perform
outputs[0][0] = scipy.linalg.solve_triangular(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/linalg/_basic.py", line 334, in solve_triangular
b1 = _asarray_validated(b, check_finite=check_finite)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/_lib/_util.py", line 321, in _asarray_validated
a = toarray(a)
^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/numpy/lib/function_base.py", line 630, in asarray_chkfinite
raise ValueError(
ValueError: array must not contain infs or NaNs
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/maximescali/Desktop/Research/MScthesis/sampling/modelsv1/pymc_inv.py", line 21, in <module>
trace = pm.sample(draws=100, tune=100, chains=1,step=step, nuts_sampler='pymc', discard_tuned_samples=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 740, in sample
model.check_start_vals(ip)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/model/core.py", line 1707, in check_start_vals
initial_eval = self.point_logps(point=elem)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/model/core.py", line 1742, in point_logps
self.compile_fn(factor_logps_fn)(point),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/pytensorf.py", line 590, in __call__
return self.f(**state)
^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 983, in __call__
raise_with_op(
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py", line 523, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
self.vm()
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 515, in rval
r = p(n, [x[0] for x in i], o)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/slinalg.py", line 298, in perform
outputs[0][0] = scipy.linalg.solve_triangular(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/linalg/_basic.py", line 334, in solve_triangular
b1 = _asarray_validated(b, check_finite=check_finite)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/_lib/_util.py", line 321, in _asarray_validated
a = toarray(a)
^^^^^^^^^^
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/numpy/lib/function_base.py", line 630, in asarray_chkfinite
raise ValueError(
ValueError: array must not contain infs or NaNs
Apply node that caused the error: SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
Toposort index: 13
Inputs types: [TensorType(float64, shape=(4, 4)), TensorType(float64, shape=(4, 3))]
Inputs shapes: [(4, 4), (4, 3)]
Inputs strides: [(8, 32), (24, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[Transpose{axes=[1, 0]}(SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}.0), Dot22(Transpose{axes=[1, 0]}.0, SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 826, in custom_dist_logp
return logp(values[0], *dist_params)
File "/Users/maximescali/Desktop/Research/MScthesis/sampling/modelsv1/pymc_inv.py", line 14, in logp
return -pt.trace(QX)**2+pm.logp(normal,QX)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/logprob/basic.py", line 211, in logp
return _logprob_helper(rv, value, **kwargs)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/logprob/abstract.py", line 68, in _logprob_helper
logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 213, in logp
return class_logp(value, *dist_params)
File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/multivariate.py", line 1876, in logp
right_quaddist = solve_lower(rowchol, delta)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
I suspect the cause could be that the matrix $X^TX$ is not positive definite when X is the zero matrix (at initialisation). However, the error message gives me no clear intuition of what is going on, since my logp takes a 'pytensor.tensor.variable.TensorVariable' as input and returns one (it does not contain a value). When I set
exception_verbosity=high
as advised, I get the PyTensor graph of the problematic node, but it still does not show me the values of the nodes.
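One way to see actual numbers is to replay the logp arithmetic in NumPy at the suspected starting point. The sketch below assumes (it is not confirmed by the traceback) that the sampler starts from the all-zeros 4x3 matrix, and prints the intermediates:

```python
import numpy as np

x0 = np.zeros((4, 3))                     # assumed starting value of Q (all zeros)
z = x0.T @ x0                             # X^T X is the 3x3 zero matrix
w, v = np.linalg.eigh(z)                  # so every eigenvalue is 0
with np.errstate(divide="ignore", invalid="ignore"):
    w_invsqrt = 1.0 / np.sqrt(w)          # 1 / sqrt(0) -> inf
    inv_sqrt = v @ np.diag(w_invsqrt) @ v.T   # products with inf give non-finite entries
    qx = x0 @ inv_sqrt                    # 0 * inf terms typically give nan

print("eigenvalues of X^T X:", w)
print("1/sqrt(w):", w_invsqrt)
print("all of QX finite?", np.isfinite(qx).all())
```

If QX is non-finite at the start, that would be consistent with the second input of the failing SolveTriangular node (Sub.0, presumably QX minus mu inside the MatrixNormal logp) containing infs or NaNs.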
step = pm.NUTS()
by
_,step = pm.init_nuts(initvals={"Q":np.eye(4,3)})
by it does not change the error.
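As a sanity check on the initvals attempt, replaying the eigh-based steps of logp in NumPy at np.eye(4, 3) stays finite, since X^T X is then the 3x3 identity. So if that value actually reached the start-value check, the check should pass. An assumption worth testing, not verified here: `pm.init_nuts` returns the initial points together with the step (the first element is discarded into `_` above), and `pm.sample` may build its own initial point from its own `initvals` argument rather than reusing them. The helper `logp_intermediates` is hypothetical:

```python
import numpy as np

def logp_intermediates(x):
    """Hypothetical helper: replay the eigh-based steps of logp in NumPy."""
    z = x.T @ x
    w, v = np.linalg.eigh(z)
    w_invsqrt = 1.0 / np.sqrt(w)
    inv_sqrt = v @ np.diag(w_invsqrt) @ v.T
    qx = x @ inv_sqrt
    return w, qx

x_init = np.eye(4, 3)                 # the value passed as initvals to pm.init_nuts
w, qx = logp_intermediates(x_init)
print(w)                              # eigenvalues of X^T X, all equal to 1
print(np.isfinite(qx).all())          # no infs or NaNs at this starting point
print(np.allclose(qx, x_init))        # eye(4, 3) already has orthonormal columns
```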
I would appreciate any help!