Maxime

Custom density with the NUTS sampler in PyMC

I am trying to sample from a distribution defined on 4x3 matrices using the NUTS sampler. Here is my code:

import pymc as pm
import pytensor.tensor as pt
import pytensor.tensor.linalg as ptl
import numpy as np

normal=pm.MatrixNormal.dist(mu=np.eye(4,3),rowcov=np.eye(4),colcov=np.eye(3),shape=(4,3))

def logp(value):
    z=ptl.matrix_dot(value.T,value)
    w,v = ptl.eigh(z)
    w_invsqrt=1/pt.sqrt(w)
    inv_sqrt=ptl.matrix_dot(v,pt.diag(w_invsqrt),v.T)
    QX=ptl.matrix_dot(value, inv_sqrt)
    return -pt.trace(QX)**2+pm.logp(normal,QX)


with pm.Model() as model:
    Q = pm.DensityDist("Q" , logp=logp, ndim_supp=2,shape=(4,3))
    step = pm.NUTS()
    trace = pm.sample(draws=100, tune=100, chains=1,step=step, nuts_sampler='pymc', discard_tuned_samples=False)
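For reference, writing X for `value` and Q for `QX`, the `logp` above restates the following unnormalized log density. This assumes `pm.logp(normal, QX)` returns the standard MatrixNormal log density with mean M = `np.eye(4,3)` and identity row and column covariances. Q is the factor with orthonormal columns ($Q^\top Q = I_3$) from the polar decomposition of X. It exists only when X has full column rank, i.e. when every eigenvalue $w_i$ of $X^TX$ is strictly positive:

$$
X^\top X = V\,\operatorname{diag}(w)\,V^\top, \qquad
Q = X\,(X^\top X)^{-1/2} = X\,V\,\operatorname{diag}(w)^{-1/2}\,V^\top
$$

$$
\log \tilde p(X) = -\left(\operatorname{tr} Q\right)^2 + \log \mathcal{MN}(Q \mid M, I_4, I_3)
= -\left(\operatorname{tr} Q\right)^2 - \tfrac{1}{2}\,\lVert Q - M \rVert_F^2 - 6\log(2\pi)
$$

Here $\operatorname{tr} Q = \sum_{i=1}^{3} Q_{ii}$ is the sum of the main-diagonal entries of the 4x3 matrix. The constant $6\log(2\pi)$ is $\tfrac{np}{2}\log(2\pi)$ with $n = 4$ rows and $p = 3$ columns.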

I am using PyMC 5.13.1 and PyTensor 2.20.0.

I get the following error:

/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/nlinalg.py:181: FutureWarning: pytensor.tensor.linalg.trace is deprecated. Use pytensor.tensor.trace instead.
  warnings.warn(
Traceback (most recent call last):
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/slinalg.py", line 298, in perform
    outputs[0][0] = scipy.linalg.solve_triangular(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/linalg/_basic.py", line 334, in solve_triangular
    b1 = _asarray_validated(b, check_finite=check_finite)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/_lib/_util.py", line 321, in _asarray_validated
    a = toarray(a)
        ^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/numpy/lib/function_base.py", line 630, in asarray_chkfinite
    raise ValueError(
ValueError: array must not contain infs or NaNs

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/maximescali/Desktop/Research/MScthesis/sampling/modelsv1/pymc_inv.py", line 21, in <module>
    trace = pm.sample(draws=100, tune=100, chains=1,step=step, nuts_sampler='pymc', discard_tuned_samples=False)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 740, in sample
    model.check_start_vals(ip)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/model/core.py", line 1707, in check_start_vals
    initial_eval = self.point_logps(point=elem)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/model/core.py", line 1742, in point_logps
    self.compile_fn(factor_logps_fn)(point),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/pytensorf.py", line 590, in __call__
    return self.f(**state)
           ^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 983, in __call__
    raise_with_op(
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py", line 523, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pytensor/tensor/slinalg.py", line 298, in perform
    outputs[0][0] = scipy.linalg.solve_triangular(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/linalg/_basic.py", line 334, in solve_triangular
    b1 = _asarray_validated(b, check_finite=check_finite)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/scipy/_lib/_util.py", line 321, in _asarray_validated
    a = toarray(a)
        ^^^^^^^^^^
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/numpy/lib/function_base.py", line 630, in asarray_chkfinite
    raise ValueError(
ValueError: array must not contain infs or NaNs
Apply node that caused the error: SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}([[1. 0. 0. ... 0. 0. 1.]], Sub.0)
Toposort index: 13
Inputs types: [TensorType(float64, shape=(4, 4)), TensorType(float64, shape=(4, 3))]
Inputs shapes: [(4, 4), (4, 3)]
Inputs strides: [(8, 32), (24, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[Transpose{axes=[1, 0]}(SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}.0), Dot22(Transpose{axes=[1, 0]}.0, SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 826, in custom_dist_logp
    return logp(values[0], *dist_params)
  File "/Users/maximescali/Desktop/Research/MScthesis/sampling/modelsv1/pymc_inv.py", line 14, in logp
    return -pt.trace(QX)**2+pm.logp(normal,QX)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/logprob/basic.py", line 211, in logp
    return _logprob_helper(rv, value, **kwargs)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/logprob/abstract.py", line 68, in _logprob_helper
    logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 213, in logp
    return class_logp(value, *dist_params)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/multivariate.py", line 1876, in logp
    right_quaddist = solve_lower(rowchol, delta)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

I suspect this is because $X^TX$ is not positive definite when X is the zero matrix (at initialisation). In that case all its eigenvalues are 0, so `1/pt.sqrt(w)` is infinite. However, the error message does not give me a clear picture of what is going on, because my logp takes a symbolic 'pytensor.tensor.variable.TensorVariable' as input and returns one (it holds no concrete value). When I set

exception_verbosity=high

as advised, I get the PyTensor graph of the problematic node, but it still does not show the values of the nodes. I also tried replacing

step = pm.NUTS()

with

_,step = pm.init_nuts(initvals={"Q":np.eye(4,3)})

but the error is unchanged.
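One way to see the intermediate values that `exception_verbosity=high` does not print is to reproduce the steps of `logp` in plain NumPy and evaluate them at candidate starting points. The sketch below is a hypothetical mirror, not PyMC's or PyTensor's implementation. The names `polar_factor` and `logp_mirror` are illustrative only. It assumes, as suspected above, that sampling starts from the zero matrix, and it writes the MatrixNormal log density with identity covariances by hand. At the zero matrix the eigenvalues of $X^TX$ are 0, so `1/sqrt(w)` is infinite and the products that follow contain `0 * inf = nan`. That would be consistent with the `array must not contain infs or NaNs` raised inside the MatrixNormal logp. At `np.eye(4, 3)` the value is finite.

```python
import numpy as np

# Hypothetical NumPy mirror of the question's `logp`, written only to inspect
# intermediate values; it is not PyMC's or PyTensor's implementation.
M = np.eye(4, 3)  # mean of the MatrixNormal used in the question


def polar_factor(x):
    """Return (w, w_invsqrt, qx) following the first lines of `logp`."""
    z = x.T @ x                      # X^T X, shape (3, 3)
    w, v = np.linalg.eigh(z)         # eigendecomposition of a symmetric matrix
    with np.errstate(divide="ignore", invalid="ignore"):
        w_invsqrt = 1.0 / np.sqrt(w)               # infinite where w == 0
        inv_sqrt = v @ np.diag(w_invsqrt) @ v.T    # 0 * inf terms give nan
        qx = x @ inv_sqrt
    return w, w_invsqrt, qx


def logp_mirror(x):
    """-tr(Q)^2 + log MatrixNormal(Q | M, I_4, I_3), identity covariances."""
    _, _, qx = polar_factor(x)
    n, p = qx.shape
    mn_logp = -0.5 * np.sum((qx - M) ** 2) - 0.5 * n * p * np.log(2 * np.pi)
    return -np.trace(qx) ** 2 + mn_logp


# Zero matrix (the suspected starting point): zero eigenvalues make the
# inverse square root infinite, and Q ends up with non-finite entries.
w0, w0_invsqrt, q0 = polar_factor(np.zeros((4, 3)))
print("eigenvalues:", w0)
print("1/sqrt(w):", w0_invsqrt)
print("Q finite?", np.isfinite(q0).all())

# np.eye(4, 3), the value passed to pm.init_nuts: X^T X = I, so Q = X.
print("logp at eye(4, 3):", logp_mirror(np.eye(4, 3)))
```

Note that `pm.init_nuts` returns the initial points as its first output, which `_,step = pm.init_nuts(...)` discards. The traceback shows the failing call is `model.check_start_vals` inside `pm.sample`, so it may be worth checking which initial values `pm.sample` itself ends up using.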

I would appreciate any help!