From 9f52e3df6d1d0959189c547ff31f951d99863d00 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Feb 2025 11:46:43 +0100
Subject: [PATCH 1/6] Fix expand change_dist_size of SymbolicRandomVariables
 with size=None

---
 pymc/distributions/distribution.py       |  2 +-
 tests/distributions/test_distribution.py | 26 ++++++++++++++++++++++--
 tests/distributions/test_shape_utils.py  | 11 ++++++++++
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 5ec5df4671..61dc731b93 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -400,7 +400,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
 
     params = op.dist_params(rv.owner)
 
-    if expand:
+    if expand and not rv_size_is_none(size):
         new_size = tuple(new_size) + tuple(size)
 
     return op.rv_op(*params, size=new_size)
diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index df97905073..d7e2bbd0a1 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -23,7 +23,8 @@
 import scipy.stats as st
 
 from pytensor import shared
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import NoneConst, TensorVariable
+from pytensor.tensor.random.utils import normalize_size_param
 
 import pymc as pm
 
@@ -43,7 +44,7 @@
 )
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.logprob.basic import conditional_logp, logp
-from pymc.pytensorf import compile
+from pymc.pytensorf import compile, normalize_rng_param
 from pymc.testing import (
     BaseTestDistributionRandom,
     I,
@@ -210,6 +211,27 @@ def test_recreate_with_different_rng_inputs(self):
         new_next_rng, new_x = x.owner.op(*inputs)
         assert op.update(new_x.owner) == {new_rng: new_next_rng}
 
+    def test_change_dist_size_none(self):
+        class TestRV(SymbolicRandomVariable):
+            extended_signature = "[rng],[size]->[rng],(n)"
+
+            @classmethod
+            def rv_op(cls, size=None, rng=None):
+                rng = normalize_rng_param(rng)
+                size = normalize_size_param(size)
+                next_rng, draws = Normal.dist(size=size, rng=rng).owner.outputs
+                return cls(inputs=[rng, size], outputs=[next_rng, draws])(rng, size)
+
+        size = NoneConst
+        rv = TestRV.rv_op(size=size)
+        assert rv.type.shape == ()
+
+        resized_rv = change_dist_size(rv, new_size=5)
+        assert resized_rv.type.shape == (5,)
+
+        resized_rv = change_dist_size(rv, new_size=5, expand=True)
+        assert resized_rv.type.shape == (5,)
+
 
 def test_tag_future_warning_dist():
     # Test no unexpected warnings
diff --git a/tests/distributions/test_shape_utils.py b/tests/distributions/test_shape_utils.py
index 8579bfd8e1..d0f3f1b432 100644
--- a/tests/distributions/test_shape_utils.py
+++ b/tests/distributions/test_shape_utils.py
@@ -427,6 +427,17 @@ def test_change_rv_size():
     assert tuple(rv_newer.shape.eval()) == (2,)
 
 
+def test_change_rv_size_expand_none_size():
+    x = pt.random.normal()
+    size = x.owner.op.size_param(x.owner)
+    assert rv_size_is_none(size)
+    new_x = change_dist_size(x, new_size=(2,), expand=True)
+    new_size = new_x.owner.op.size_param(new_x.owner)
+    assert not rv_size_is_none(new_size)
+    assert new_size.data == [2]
+    assert new_x.type.shape == (2,)
+
+
 def test_change_rv_size_default_update():
     rng = pytensor.shared(np.random.default_rng(0))
     x = normal(rng=rng)

From 29e10e2d38e0997ebbf9e5aa73352c6214f21771 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 4 Mar 2025 12:15:30 +0100
Subject: [PATCH 2/6] Bump PyTensor dependency

---
 conda-envs/environment-dev.yml          | 2 +-
 conda-envs/environment-docs.yml         | 2 +-
 conda-envs/environment-jax.yml          | 2 +-
 conda-envs/environment-test.yml         | 2 +-
 conda-envs/windows-environment-dev.yml  | 2 +-
 conda-envs/windows-environment-test.yml | 2 +-
 requirements-dev.txt                    | 2 +-
 requirements.txt                        | 2 +-
 8 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index b2e8f17849..857f7ef513 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -12,7 +12,7 @@ dependencies:
 - numpy>=1.25.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml
index 823f6563f6..267c77cf2a 100644
--- a/conda-envs/environment-docs.yml
+++ b/conda-envs/environment-docs.yml
@@ -11,7 +11,7 @@ dependencies:
 - numpy>=1.25.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - rich>=13.7.1
 - scipy>=1.4.1
diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml
index ea630e806f..4a16955faf 100644
--- a/conda-envs/environment-jax.yml
+++ b/conda-envs/environment-jax.yml
@@ -20,7 +20,7 @@ dependencies:
 - numpyro>=0.8.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 37729d4d1a..dc732cfc49 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -14,7 +14,7 @@ dependencies:
 - pandas>=0.24.0
 - pip
 - polyagamma
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index 1c4d8f04d5..989a0558c4 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -12,7 +12,7 @@ dependencies:
 - numpy>=1.25.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index 63f5f8ee4e..9435b068b0 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -15,7 +15,7 @@ dependencies:
 - pandas>=0.24.0
 - pip
 - polyagamma
-- pytensor>=2.28.1,<2.29
+- pytensor>=2.28.3,<2.29
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/requirements-dev.txt b/requirements-dev.txt
index c868dbd52e..becf5c7fcc 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -16,7 +16,7 @@ numpydoc
 pandas>=0.24.0
 polyagamma
 pre-commit>=2.8.0
-pytensor>=2.28.1,<2.29
+pytensor>=2.28.3,<2.29
 pytest-cov>=2.5
 pytest>=3.0
 rich>=13.7.1
diff --git a/requirements.txt b/requirements.txt
index fb0c131ce7..c0a6151122 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ cachetools>=4.2.1
 cloudpickle
 numpy>=1.25.0
 pandas>=0.24.0
-pytensor>=2.28.1,<2.29
+pytensor>=2.28.3,<2.29
 rich>=13.7.1
 scipy>=1.4.1
 threadpoolctl>=3.1.0,<4.0.0

From 9e05c4c923a36bfc0eacb9a2c6d00d8a2ecbbe9e Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 4 Mar 2025 12:17:06 +0100
Subject: [PATCH 3/6] Compatible `normalize_axis_tuple` import

---
 pymc/distributions/transforms.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index c8ca8d0554..be0df56541 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -18,11 +18,8 @@
 import numpy as np
 import pytensor.tensor as pt
-
-# ignore mypy error because it somehow considers that
-# "numpy.core.numeric has no attribute normalize_axis_tuple"
-from numpy.core.numeric import normalize_axis_tuple  # type: ignore[attr-defined]
 from pytensor.graph import Op
+from pytensor.npy_2_compat import normalize_axis_tuple
 from pytensor.tensor import TensorVariable
 
 from pymc.logprob.transforms import (

From 77e01993980f04475e69ed78a62d45c946bc1663 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 19 Feb 2025 10:19:53 +0100
Subject: [PATCH 4/6] Make more distributions symbolic so they work in
 different backends

---
 .github/workflows/tests.yml                   |  18 +-
 ...l => environment-alternative-backends.yml} |   2 +
 pymc/distributions/continuous.py              |  29 ++-
 pymc/distributions/multivariate.py            | 210 ++++++++----------
 pymc/distributions/shape_utils.py             |   3 +-
 tests/distributions/test_multivariate.py      |  46 ++--
 .../test_random_alternative_backends.py       |  70 ++++++
 tests/sampling/test_jax.py                    |  27 +--
 8 files changed, 216 insertions(+), 189 deletions(-)
 rename conda-envs/{environment-jax.yml => environment-alternative-backends.yml} (96%)
 create mode 100644 tests/distributions/test_random_alternative_backends.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 97e50bef2a..a1460909c5 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -281,16 +281,20 @@ jobs:
           name: ${{ matrix.os }} ${{ matrix.floatx }}
           fail_ci_if_error: false
 
-  external_samplers:
+  alternative_backends:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
        os: [ubuntu-20.04]
        floatx: [float64]
-       python-version: ["3.13"]
+       python-version: ["3.12"]
        test-subset:
-         - tests/sampling/test_jax.py tests/sampling/test_mcmc_external.py
+         - |
+           tests/distributions/test_random_alternative_backends.py
+           tests/sampling/test_jax.py
+           tests/sampling/test_mcmc_external.py
+
      fail-fast: false
    runs-on: ${{ matrix.os }}
    env:
@@ -305,7 +309,7 @@ jobs:
           persist-credentials: false
      - uses: mamba-org/setup-micromamba@v2
        with:
-         environment-file: conda-envs/environment-jax.yml
+         environment-file: conda-envs/environment-alternative-backends.yml
          create-args: >-
            python=${{matrix.python-version}}
          environment-name: pymc-test
@@ -324,7 +328,7 @@ jobs:
        with:
          token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
          env_vars: TEST_SUBSET
-         name: JAX tests - ${{ matrix.os }} ${{ matrix.floatx }}
+         name: Alternative backend tests - ${{ matrix.os }} ${{ matrix.floatx }}
          fail_ci_if_error: false
 
   float32:
@@ -378,13 +382,13 @@ jobs:
   all_tests:
     if: ${{ always() }}
     runs-on: ubuntu-latest
-    needs: [ changes, ubuntu, windows, macos, external_samplers, float32 ]
+    needs: [ changes, ubuntu, windows, macos, alternative_backends, float32 ]
     steps:
       - name: Check build matrix status
        if: ${{ needs.changes.outputs.changes == 'true' &&
                ( needs.ubuntu.result != 'success' ||
                  needs.windows.result != 'success' ||
                  needs.macos.result != 'success' ||
-                 needs.external_samplers.result != 'success' ||
+                 needs.alternative_backends.result != 'success' ||
                  needs.float32.result != 'success' ) }}
        run: exit 1
diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-alternative-backends.yml
similarity index 96%
rename from conda-envs/environment-jax.yml
rename to conda-envs/environment-alternative-backends.yml
index 4a16955faf..b51ee3308d 100644
--- a/conda-envs/environment-jax.yml
+++ b/conda-envs/environment-alternative-backends.yml
@@ -10,6 +10,8 @@ dependencies:
 - cachetools>=4.2.1
 - cloudpickle
 - zarr>=2.5.0,<3
+- numba
+- nutpie >= 0.13.4
 # Jaxlib version must not be greater than jax version!
 - blackjax>=1.2.2
 - jax>=0.4.28
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index 082be31d5c..13228794d3 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -2595,23 +2595,27 @@ def dist(cls, nu, **kwargs):
         return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)
 
 
-class WeibullBetaRV(RandomVariable):
+class WeibullBetaRV(SymbolicRandomVariable):
     name = "weibull"
-    signature = "(),()->()"
-    dtype = "floatX"
+    extended_signature = "[rng],[size],(),()->[rng],()"
     _print_name = ("Weibull", "\\operatorname{Weibull}")
 
-    def __call__(self, alpha, beta, size=None, **kwargs):
-        return super().__call__(alpha, beta, size=size, **kwargs)
-
     @classmethod
-    def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
-        if size is None:
-            size = np.broadcast_shapes(alpha.shape, beta.shape)
-        return np.asarray(beta * rng.weibull(alpha, size=size))
+    def rv_op(cls, alpha, beta, *, rng=None, size=None):
+        alpha = pt.as_tensor(alpha)
+        beta = pt.as_tensor(beta)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
 
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params)
 
-weibull_beta = WeibullBetaRV()
+        next_rng, raw_weibull = pt.random.weibull(alpha, size=size, rng=rng).owner.outputs
+        draws = beta * raw_weibull
+        return cls(
+            inputs=[rng, size, alpha, beta],
+            outputs=[next_rng, draws],
+        )(rng, size, alpha, beta)
 
 
 class Weibull(PositiveContinuous):
@@ -2660,7 +2664,8 @@ class Weibull(PositiveContinuous):
         Scale parameter (beta > 0).
     """
 
-    rv_op = weibull_beta
+    rv_type = WeibullBetaRV
+    rv_op = WeibullBetaRV.rv_op
 
     @classmethod
     def dist(cls, alpha, beta, *args, **kwargs):
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 9ff56027a8..24fa8bae56 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -37,10 +37,10 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
 from pytensor.tensor.linalg import inv as matrix_inverse
+from pytensor.tensor.random import chisquare
 from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import (
-    broadcast_params,
     normalize_size_param,
 )
 from pytensor.tensor.type import TensorType
@@ -365,33 +365,37 @@
 )
 
 
-class MvStudentTRV(RandomVariable):
+class MvStudentTRV(SymbolicRandomVariable):
+    r"""A multivariate Student's t random variable.
+
+    It is defined as a multivariate normal draw divided by the square root of an independent scaled chi-squared draw.
+    """
+
     name = "multivariate_studentt"
-    signature = "(),(n),(n,n)->(n)"
-    dtype = "floatX"
+    extended_signature = "[rng],[size],(),(n),(n,n)->[rng],(n)"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
     @classmethod
-    def rng_fn(cls, rng, nu, mu, cov, size):
-        if size is None:
-            # When size is implicit, we need to broadcast parameters correctly,
-            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
-            # nu broadcasts mu and cov
-            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
-                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
-            # nu is broadcasted by either mu or cov
-            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
-                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
-
-        mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
-
-        # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
-        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
+    def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
+        nu = pt.as_tensor(nu)
+        mean = pt.as_tensor(mean)
+        scale = pt.as_tensor(scale)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
 
-        return (mv_samples / chi2_samples) + mu
+        if rv_size_is_none(size):
+            size = implicit_size_from_params(nu, mean, scale, ndims_params=cls.ndims_params)
 
+        next_rng, mv_draws = multivariate_normal(
+            mean.zeros_like(), scale, size=size, rng=rng
+        ).owner.outputs
+        next_rng, chi2_draws = chisquare(nu, size=size, rng=next_rng).owner.outputs
+        draws = mean + (mv_draws / pt.sqrt(chi2_draws / nu)[..., None])
 
-mv_studentt = MvStudentTRV()
+        return cls(
+            inputs=[rng, size, nu, mean, scale],
+            outputs=[next_rng, draws],
+        )(rng, size, nu, mean, scale)
 
 
 class MvStudentT(Continuous):
@@ -435,7 +439,8 @@ class MvStudentT(Continuous):
         Whether the Cholesky factor is given as a lower triangular matrix.
     """
 
-    rv_op = mv_studentt
+    rv_type = MvStudentTRV
+    rv_op = MvStudentTRV.rv_op
 
     @classmethod
     def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
@@ -1152,57 +1157,6 @@ def _lkj_normalizing_constant(eta, n):
     return result
 
 
-class _LKJCholeskyCovBaseRV(RandomVariable):
-    name = "_lkjcholeskycovbase"
-    signature = "(),(),(d)->(n)"
-    dtype = "floatX"
-    _print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")
-
-    def make_node(self, rng, size, n, eta, D):
-        n = pt.as_tensor_variable(n)
-        if not all(n.type.broadcastable):
-            raise ValueError("n must be a scalar.")
-
-        eta = pt.as_tensor_variable(eta)
-        if not all(eta.type.broadcastable):
-            raise ValueError("eta must be a scalar.")
-
-        D = pt.as_tensor_variable(D)
-
-        return super().make_node(rng, size, n, eta, D)
-
-    def _supp_shape_from_params(self, dist_params, param_shapes):
-        n = dist_params[0].squeeze()
-        return ((n * (n + 1)) // 2,)
-
-    def rng_fn(self, rng, n, eta, D, size):
-        # We flatten the size to make operations easier, and then rebuild it
-        if size is None:
-            size = D.shape[:-1]
-        flat_size = np.prod(size).astype(int)
-
-        n = n.squeeze()
-        eta = eta.squeeze()
-
-        C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D = D.reshape(flat_size, n)
-        C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
-
-        tril_idx = np.tril_indices(n, k=0)
-        samples = np.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
-
-        if size is None:
-            samples = samples[0]
-        else:
-            dist_shape = (n * (n + 1)) // 2
-            samples = np.reshape(samples, (*size, dist_shape))
-
-        return samples
-
-
-_ljk_cholesky_cov_base = _LKJCholeskyCovBaseRV()
-
-
 # _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
 # be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
 class _LKJCholeskyCovRV(SymbolicRandomVariable):
@@ -1223,21 +1177,40 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         # for each diagonal element.
         # Since `eta` and `n` are forced to be scalars we don't need to worry about
         # implied batched dimensions from those for the time being.
# Since `eta` and `n` are forced to be scalars we don't need to worry about # implied batched dimensions from those for the time being. + if rv_size_is_none(size): - size = sd_dist.shape[:-1] + sd_dist_size = sd_dist.shape[:-1] + else: + sd_dist_size = size - shape = (*size, n) if sd_dist.owner.op.ndim_supp == 0: - sd_dist = change_dist_size(sd_dist, shape) + sd_dist = change_dist_size(sd_dist, (*sd_dist_size, n)) else: # The support shape must be `n` but we have no way of controlling it - sd_dist = change_dist_size(sd_dist, shape[:-1]) + sd_dist = change_dist_size(sd_dist, sd_dist_size) + + D = sd_dist.type(name="D") # Make sd_dist opaque to OpFromGraph + size = D.shape[:-1] + + # We flatten the size to make operations easier, and then rebuild it + flat_size = pt.prod(size, dtype="int64") + + next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) + D_matrix = D.reshape((flat_size, n)) + C *= D_matrix[..., :, None] * D_matrix[..., None, :] - next_rng, lkjcov = _ljk_cholesky_cov_base(n, eta, sd_dist, rng=rng).owner.outputs + tril_idx = pt.tril_indices(n, k=0) + samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]] + + if rv_size_is_none(size): + samples = samples[0] + else: + dist_shape = (n * (n + 1)) // 2 + samples = pt.reshape(samples, (*size, dist_shape)) return _LKJCholeskyCovRV( - inputs=[rng, n, eta, sd_dist], - outputs=[next_rng, lkjcov], + inputs=[rng, n, eta, D], + outputs=[next_rng, samples], )(rng, n, eta, sd_dist) def update(self, node): @@ -1508,10 +1481,9 @@ def helper_deterministics(cls, n, packed_chol): return chol, corr, stds -class LKJCorrRV(RandomVariable): +class LKJCorrRV(SymbolicRandomVariable): name = "lkjcorr" - signature = "(),()->(n)" - dtype = "floatX" + extended_signature = "[rng],[size],(),()->[rng],(n)" _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}") def make_node(self, rng, size, n, eta): @@ -1525,55 +1497,66 @@ def make_node(self, rng, size, n, eta): return super().make_node(rng, size, n, eta) - def _supp_shape_from_params(self, dist_params, **kwargs): - n = dist_params[0].squeeze() - dist_shape = ((n * (n - 1)) // 2,) - return dist_shape - @classmethod - def rng_fn(cls, rng, n, eta, size): + def rv_op(cls, n: int, eta, *, rng=None, size=None): # We flatten the size to make operations easier, and then rebuild it - if size is None: + n = pt.as_tensor(n, ndim=0, dtype=int) + eta = pt.as_tensor(eta, ndim=0) + rng = normalize_rng_param(rng) + size = normalize_size_param(size) + + if rv_size_is_none(size): flat_size = 1 else: - flat_size = np.prod(size).astype(int) + flat_size = pt.prod(size, dtype="int64") - n = n.squeeze() - eta = eta.squeeze() - C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) + next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) - triu_idx = np.triu_indices(n, k=1) + triu_idx = pt.triu_indices(n, k=1) samples = C[..., triu_idx[0], triu_idx[1]] - if size is None: + if rv_size_is_none(size): samples = samples[0] else: dist_shape = (n * (n - 1)) // 2 - samples = np.reshape(samples, (*size, dist_shape)) + samples = pt.reshape(samples, (*size, dist_shape)) + + return cls( + inputs=[rng, size, n, eta], + outputs=[next_rng, samples], + )(rng, size, n, eta) + return samples @classmethod - def _random_corr_matrix(cls, rng, n, eta, flat_size): + def _random_corr_matrix( + cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable + ) -> tuple[Variable, TensorVariable]: # original implementation in R see: # 
        #   https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
         beta = eta - 1.0 + n / 2.0
-        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
-        P = np.full((flat_size, n, n), np.eye(n))
-        P[..., 0, 1] = r12
-        P[..., 1, 1] = np.sqrt(1.0 - r12**2)
+        next_rng, beta_rvs = pt.random.beta(
+            alpha=beta, beta=beta, size=flat_size, rng=rng
+        ).owner.outputs
+        r12 = 2.0 * beta_rvs - 1.0
+        P = pt.full((flat_size, n, n), pt.eye(n))
+        P = P[..., 0, 1].set(r12)
+        P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
+        n = get_underlying_scalar_constant_value(n)
         for mp1 in range(2, n):
             beta -= 0.5
-            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=flat_size, random_state=rng)
-            z = stats.norm.rvs(loc=0, scale=1, size=(flat_size, mp1), random_state=rng)
-            z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
-            P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
-            P[..., mp1, mp1] = np.sqrt(1.0 - y)
-        C = np.einsum("...ji,...jk->...ik", P, P)
-        return C
-
-
-lkjcorr = LKJCorrRV()
+            next_rng, y = pt.random.beta(
+                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+            ).owner.outputs
+            next_rng, z = pt.random.normal(
+                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+            ).owner.outputs
+            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+            P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
+            P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
+        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
+        return next_rng, C
 
 
 class MultivariateIntervalTransform(Interval):
@@ -1585,7 +1568,8 @@ def log_jac_det(self, *args):
 
 
 # Returns list of upper triangular values
 class _LKJCorr(BoundedContinuous):
-    rv_op = lkjcorr
+    rv_type = LKJCorrRV
+    rv_op = LKJCorrRV.rv_op
 
     @classmethod
     def dist(cls, n, eta, **kwargs):
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index 98c743b70e..6f54aba2d1 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -468,5 +468,6 @@ def implicit_size_from_params(
         pt.broadcast_shape(
             *batch_shapes,
             arrays_are_shapes=True,
-        )
+        ),
+        dtype="int64",  # In case it's empty, as_tensor will default to floatX
     )
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index 2694230c32..b184e04afa 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -1310,32 +1310,15 @@ def test_kronecker_normal_support_point(self, mu, covs, size, expected):
         [
             (3, 1, None, np.zeros(3)),
             (5, 1, None, np.zeros(10)),
-            pytest.param(
-                3,
-                1,
-                1,
-                np.zeros((1, 3)),
-                marks=pytest.mark.xfail(
-                    raises=NotImplementedError,
-                    reason="LKJCorr logp is only implemented for vector values (ndim=1)",
-                ),
-            ),
-            pytest.param(
-                5,
-                1,
-                (2, 3),
-                np.zeros((2, 3, 10)),
-                marks=pytest.mark.xfail(
-                    raises=NotImplementedError,
-                    reason="LKJCorr logp is only implemented for vector values (ndim=1)",
-                ),
-            ),
+            pytest.param(3, 1, 1, np.zeros((1, 3))),
+            pytest.param(5, 1, (2, 3), np.zeros((2, 3, 10))),
         ],
     )
     def test_lkjcorr_support_point(self, n, eta, size, expected):
         with pm.Model() as model:
             pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False)
-        assert_support_point_is_expected(model, expected)
+        # LKJCorr logp is only implemented for vector values (size=None)
+        assert_support_point_is_expected(model, expected, check_finite_logp=size is None)
 
     @pytest.mark.parametrize(
         "n, eta, size, expected",
@@ -2190,15 +2173,18 @@ def ref_rand(size, n, eta):
         beta = eta - 1 + n / 2
         return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
 
-    continuous_random_tester(
-        _LKJCorr,
-        {
-            "n": Domain([2, 10, 50], edges=(None, None)),
-            "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
-        },
-        ref_rand=ref_rand,
-        size=1000,
-    )
+    # If passed as a domain, continuous_random_tester would make `n` a shared variable
+    # But this RV needs it to be constant in order to define the inner graph
+    for n in (2, 10, 50):
+        continuous_random_tester(
+            _LKJCorr,
+            {
+                "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
+            },
+            extra_args={"n": n},
+            ref_rand=ft.partial(ref_rand, n=n),
+            size=1000,
+        )
 
 
 @pytest.mark.parametrize(
diff --git a/tests/distributions/test_random_alternative_backends.py b/tests/distributions/test_random_alternative_backends.py
new file mode 100644
index 0000000000..98214cdae9
--- /dev/null
+++ b/tests/distributions/test_random_alternative_backends.py
@@ -0,0 +1,70 @@
+#   Copyright 2025 - present The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+from contextlib import nullcontext
+
+import numpy as np
+import pytest
+
+import pymc as pm
+
+from pymc import DirichletMultinomial, MvStudentT
+from pymc.model.transform.optimization import freeze_dims_and_data
+
+
+@pytest.fixture(params=["FAST_RUN", "JAX", "NUMBA"])
+def mode(request):
+    mode_param = request.param
+    if mode_param != "FAST_RUN":
+        pytest.importorskip(mode_param.lower())
+    return mode_param
+
+
+def test_dirichlet_multinomial(mode):
+    """Test we can draw from a DM in the JAX backend if the shape is constant."""
+    dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
+    dm_draws = pm.draw(dm, mode=mode)
+    np.testing.assert_equal(dm_draws, np.eye(3) * 5)
+
+
+def test_dirichlet_multinomial_dims(mode):
+    """Test we can draw from a DM with a shape defined by dims in the JAX backend,
+    after freezing those dims.
+ """ + with pm.Model(coords={"trial": range(3), "item": range(3)}) as m: + dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item")) + + # JAX does not allow us to JIT a function with dynamic shape + expected_ctxt = pytest.raises(TypeError) if mode == "JAX" else nullcontext() + with expected_ctxt: + pm.draw(dm, mode=mode) + + # Should be fine after freezing the dims that specify the shape + frozen_dm = freeze_dims_and_data(m)["dm"] + dm_draws = pm.draw(frozen_dm, mode=mode) + np.testing.assert_equal(dm_draws, np.eye(3) * 5) + + +def test_mvstudentt(mode): + mvt = MvStudentT.dist(nu=100, mu=[1, 2, 3], scale=np.eye(3) * [0.01, 1, 100], shape=(10_000, 3)) + draws = pm.draw(mvt, mode=mode) + np.testing.assert_allclose(draws.mean(0), [1, 2, 3], rtol=0.1) + np.testing.assert_allclose(draws.std(0), np.sqrt([0.01, 1, 100]), rtol=0.1) + + +def test_repeated_arguments(mode): + # Regression test for a failure in Numba mode when a RV had repeated arguments + v = 0.5 * 1e5 + x = pm.Beta.dist(v, v) + x_draw = pm.draw(x, mode=mode) + np.testing.assert_allclose(x_draw, 0.5, rtol=0.01) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index ddec60e539..0205c4ebf7 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -34,8 +34,7 @@ import pymc as pm from pymc import ImputationWarning -from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix -from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.distributions.multivariate import PosDefMatrix from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, @@ -505,27 +504,3 @@ def test_convergence_warnings(caplog, nuts_sampler): [record] = caplog.records assert re.match(r"There were \d+ divergences after tuning", record.message) - - -def test_dirichlet_multinomial(): - """Test we can draw from a DM in the JAX backend if the shape is constant.""" - dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01) - dm_draws = pm.draw(dm, mode="JAX") - np.testing.assert_equal(dm_draws, np.eye(3) * 5) - - -def test_dirichlet_multinomial_dims(): - """Test we can draw from a DM with a shape defined by dims in the JAX backend, - after freezing those dims. 
- """ - with pm.Model(coords={"trial": range(3), "item": range(3)}) as m: - dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item")) - - # JAX does not allow us to JIT a function with dynamic shape - with pytest.raises(TypeError): - pm.draw(dm, mode="JAX") - - # Should be fine after freezing the dims that specify the shape - frozen_dm = freeze_dims_and_data(m)["dm"] - dm_draws = pm.draw(frozen_dm, mode="JAX") - np.testing.assert_equal(dm_draws, np.eye(3) * 5) From 81aea4aa178e979edacb730740871eb9b2c9967f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 27 Feb 2025 13:06:57 +0100 Subject: [PATCH 5/6] Allow forwarding of MvNormal method to SymbolicRandomVariables --- pymc/distributions/distribution.py | 10 +++++- pymc/distributions/multivariate.py | 41 ++++++++++++++++++------ tests/distributions/test_multivariate.py | 20 ++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 61dc731b93..6e4e9ba377 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -383,6 +383,14 @@ def batch_ndim(self, node: Apply) -> int: out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs) return out_ndim - self.ndim_supp + def rebuild_rv(self, *args, **kwargs): + """Rebuild the RandomVariable with new inputs.""" + if not hasattr(self, "rv_op"): + raise NotImplementedError( + f"SymbolicRandomVariable {self} without `rv_op` method cannot be rebuilt automatically." + ) + return self.rv_op(*args, **kwargs) + @_change_dist_size.register(SymbolicRandomVariable) def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> TensorVariable: @@ -403,7 +411,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> if expand and not rv_size_is_none(size): new_size = tuple(new_size) + tuple(size) - return op.rv_op(*params, size=new_size) + return op.rebuild_rv(*params, size=new_size) class Distribution(metaclass=DistributionMeta): diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 24fa8bae56..04ff858e34 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -302,7 +302,19 @@ def logp(value, mu, cov): ) -class PrecisionMvNormalRV(SymbolicRandomVariable): +class SymbolicMVNormalUsedInternally(SymbolicRandomVariable): + """Helper subclass that handles the forwarding / caching of method to `MvNormal` used internally.""" + + def __init__(self, *args, method: str, **kwargs): + super().__init__(*args, **kwargs) + self.method = method + + def rebuild_rv(self, *args, **kwargs): + # rv_op is a classmethod, so it doesn't have access to the instance method + return self.rv_op(*args, method=self.method, **kwargs) + + +class PrecisionMvNormalRV(SymbolicMVNormalUsedInternally): r"""A specialized multivariate normal random variable defined in terms of precision. This class is introduced during specialization logprob rewrites, and not meant to be used directly. 
@@ -313,14 +325,17 @@ class PrecisionMvNormalRV(SymbolicRandomVariable):
     _print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")
 
     @classmethod
-    def rv_op(cls, mean, tau, *, rng=None, size=None):
+    def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None):
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
         cov = pt.linalg.inv(tau)
-        next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
+        next_rng, draws = multivariate_normal(
+            mean, cov, size=size, rng=rng, method=method
+        ).owner.outputs
         return cls(
             inputs=[rng, size, mean, tau],
             outputs=[next_rng, draws],
+            method=method,
         )(rng, size, mean, tau)
 
 
@@ -354,7 +369,9 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
     rng, size, mu, cov = node.inputs
     if cov.owner and cov.owner.op == matrix_inverse:
         tau = cov.owner.inputs[0]
-        return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
+        return PrecisionMvNormalRV.rv_op(
+            mu, tau, size=size, rng=rng, method=node.op.method
+        ).owner.outputs
     return None
 
 
@@ -365,7 +382,7 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
 )
 
 
-class MvStudentTRV(SymbolicRandomVariable):
+class MvStudentTRV(SymbolicMVNormalUsedInternally):
     r"""A multivariate Student's t random variable.
 
     It is defined as a multivariate normal draw divided by the square root of an independent scaled chi-squared draw.
@@ -376,7 +393,7 @@ class MvStudentTRV(SymbolicRandomVariable):
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
     @classmethod
-    def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
+    def rv_op(cls, nu, mean, scale, *, method: str = "cholesky", rng=None, size=None):
         nu = pt.as_tensor(nu)
         mean = pt.as_tensor(mean)
         scale = pt.as_tensor(scale)
@@ -387,7 +404,7 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
             size = implicit_size_from_params(nu, mean, scale, ndims_params=cls.ndims_params)
 
         next_rng, mv_draws = multivariate_normal(
-            mean.zeros_like(), scale, size=size, rng=rng
+            mean.zeros_like(), scale, size=size, rng=rng, method=method
         ).owner.outputs
         next_rng, chi2_draws = chisquare(nu, size=size, rng=next_rng).owner.outputs
         draws = mean + (mv_draws / pt.sqrt(chi2_draws / nu)[..., None])
@@ -395,6 +412,7 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
         return cls(
             inputs=[rng, size, nu, mean, scale],
             outputs=[next_rng, draws],
+            method=method,
         )(rng, size, nu, mean, scale)
 
 
@@ -1923,12 +1941,12 @@ def logp(value, mu, rowchol, colchol):
     return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
 
 
-class KroneckerNormalRV(SymbolicRandomVariable):
+class KroneckerNormalRV(SymbolicMVNormalUsedInternally):
     ndim_supp = 1
     _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
 
     @classmethod
-    def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
+    def rv_op(cls, mu, sigma, *covs, method: str = "cholesky", size=None, rng=None):
         mu = pt.as_tensor(mu)
         sigma = pt.as_tensor(sigma)
         covs = [pt.as_tensor(cov) for cov in covs]
@@ -1937,7 +1955,9 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
 
         cov = reduce(pt.linalg.kron, covs)
         cov = cov + sigma**2 * pt.eye(cov.shape[-2])
-        next_rng, draws = multivariate_normal(mean=mu, cov=cov, size=size, rng=rng).owner.outputs
+        next_rng, draws = multivariate_normal(
+            mean=mu, cov=cov, size=size, rng=rng, method=method
+        ).owner.outputs
 
         covs_sig = ",".join(f"(a{i},b{i})" for i in range(len(covs)))
         extended_signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"
@@ -1946,6 +1966,7 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
             inputs=[rng, size, mu, sigma, *covs],
             outputs=[next_rng, draws],
             extended_signature=extended_signature,
+            method=method,
         )(rng, size, mu, sigma, *covs)
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index b184e04afa..cb4b8520b9 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -2469,6 +2469,26 @@ def test_mvstudentt_mu_convenience():
     np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))
 
 
+def test_mvstudentt_method():
+    def all_svd_method(fgraph):
+        found_one = False
+        for node in fgraph.toposort():
+            if isinstance(node.op, pm.MvNormal.rv_type):
+                found_one = True
+                if not node.op.method == "svd":
+                    return False
+        return found_one  # We want to fail if there were no MvNormal nodes
+
+    x = pm.MvStudentT.dist(nu=4, scale=np.eye(3), method="svd")
+    assert x.type.shape == (3,)
+    assert all_svd_method(x.owner.op.fgraph)
+
+    # Changing the size should preserve the method
+    resized_x = change_dist_size(x, (2,))
+    assert resized_x.type.shape == (2, 3)
+    assert all_svd_method(resized_x.owner.op.fgraph)
+
+
 def test_precision_mv_normal_optimization():
     rng = np.random.default_rng(sum(map(ord, "be precise")))
 

From 48ffc97dc08931928b6c28ead60756c52b35a2ff Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 6 Mar 2025 15:52:47 +0100
Subject: [PATCH 6/6] Try a new seed

---
 tests/sampling/test_mcmc_external.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py
index 4ab3ed5e87..2d32277061 100644
--- a/tests/sampling/test_mcmc_external.py
+++ b/tests/sampling/test_mcmc_external.py
@@ -81,7 +81,8 @@ def test_step_args():
             nuts_sampler="numpyro",
             target_accept=0.5,
             nuts={"max_treedepth": 10},
-            random_seed=1410,
+            random_seed=1411,
+            progressbar=False,
        )
 
    npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
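
The headline change in patch 1 is easiest to see outside the test suite. A minimal
sketch of the new behavior, mirroring `test_change_rv_size_expand_none_size` above
(it assumes a checkout with these patches applied):

    import pytensor.tensor as pt

    from pymc.distributions.shape_utils import change_dist_size

    # A normal created without an explicit size: its symbolic size input is None.
    x = pt.random.normal()
    assert x.type.shape == ()

    # With expand=True there is no existing size to prepend to, so the None size
    # is simply replaced by the new one rather than raising.
    new_x = change_dist_size(x, new_size=(2,), expand=True)
    assert new_x.type.shape == (2,)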
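
Patch 4 matters because the rewritten distributions are now plain PyTensor graphs.
A hedged sketch of what that enables, patterned on the new
`test_random_alternative_backends.py` module; the expected value in the final
assertion is the textbook Weibull mean, beta * gamma(1 + 1/alpha), and passing
`mode="NUMBA"` or `mode="JAX"` (as the test fixture does) assumes that backend is
installed:

    import numpy as np
    import pymc as pm

    # Weibull is now a SymbolicRandomVariable built from pt.random.weibull, so the
    # same graph can be compiled with alternative backends via the `mode` kwarg.
    x = pm.Weibull.dist(alpha=2.0, beta=10.0, size=100_000)
    draws = pm.draw(x, random_seed=1)  # add mode="NUMBA" to target Numba

    # For alpha=2, beta=10 the mean is 10 * gamma(1.5) = 10 * sqrt(pi) / 2.
    np.testing.assert_allclose(draws.mean(), 10 * np.sqrt(np.pi) / 2, rtol=0.05)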
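
Patch 5's `rebuild_rv` hook is what keeps op-level keyword state alive across
resizing. A sketch distilled from `test_mvstudentt_method` above (same assumption
of a patched checkout):

    import numpy as np
    import pymc as pm

    from pymc.distributions.shape_utils import change_dist_size

    # method="svd" is forwarded to the MvNormal drawn inside MvStudentT...
    x = pm.MvStudentT.dist(nu=4, scale=np.eye(3), method="svd")
    assert x.type.shape == (3,)

    # ...and survives resizing, because change_dist_size now goes through
    # rebuild_rv, which re-applies the method cached on the op.
    resized_x = change_dist_size(x, (2,))
    assert resized_x.type.shape == (2, 3)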