diff --git a/pymc_experimental/distributions/multivariate/r2d2m2cp.py b/pymc_experimental/distributions/multivariate/r2d2m2cp.py
index 62c706de..1ef0e78d 100644
--- a/pymc_experimental/distributions/multivariate/r2d2m2cp.py
+++ b/pymc_experimental/distributions/multivariate/r2d2m2cp.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 
-from typing import Sequence, Union
+from collections import namedtuple
+from typing import Sequence, Tuple, Union
 
 import numpy as np
 import pymc as pm
@@ -22,14 +23,23 @@
 __all__ = ["R2D2M2CP"]
 
 
-def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
+def _psivar2musigma(
+    psi: pt.TensorVariable,
+    explained_var: pt.TensorVariable,
+    psi_mask: Union[pt.TensorLike, None],
+) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
+    sign = pt.sign(psi - 0.5)
+    if psi_mask is not None:
+        # any computation might be ignored for ~psi_mask
+        # sign and explained_var are used
+        psi = pt.where(psi_mask, psi, 0.5)
     pi = pt.erfinv(2 * psi - 1)
     f = (1 / (2 * pi**2 + 1)) ** 0.5
     sigma = explained_var**0.5 * f
     mu = sigma * pi * 2**0.5
     if psi_mask is not None:
         return (
-            pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5),
+            pt.where(psi_mask, mu, sign * explained_var**0.5),
             pt.where(psi_mask, sigma, 0),
         )
     else:
@@ -47,7 +57,7 @@ def _R2D2M2CP_beta(
     psi_mask,
     dims: Union[str, Sequence[str]],
     centered=False,
-):
+) -> pt.TensorVariable:
     """R2D2M2CP beta prior.
 
     Parameters
     ----------
@@ -65,7 +75,7 @@
     psi: tensor
         probability of a coefficients to be positive
     """
-    explained_variance = phi * pt.expand_dims(r2 * output_sigma**2, -1)
+    explained_variance = phi * pt.expand_dims(r2 * output_sigma**2, (-1,))
     mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
     if not centered:
         with pm.Model(name):
@@ -107,7 +117,10 @@
     return beta
 
 
-def _broadcast_as_dims(*values, dims):
+def _broadcast_as_dims(
+    *values: np.ndarray,
+    dims: Sequence[str],
+) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
     model = pm.modelcontext(None)
     shape = [len(model.coords[d]) for d in dims]
     ret = tuple(np.broadcast_to(v, shape) for v in values)
@@ -117,7 +130,12 @@
     return ret
 
 
-def _psi_masked(positive_probs, positive_probs_std, *, dims):
+def _psi_masked(
+    positive_probs: pt.TensorLike,
+    positive_probs_std: pt.TensorLike,
+    *,
+    dims: Sequence[str],
+) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
     if not (
         isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
     ):
@@ -152,7 +170,12 @@
     return mask, psi
 
 
-def _psi(positive_probs, positive_probs_std, *, dims):
+def _psi(
+    positive_probs: pt.TensorLike,
+    positive_probs_std: Union[pt.TensorLike, None],
+    *,
+    dims: Sequence[str],
+) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
     if positive_probs_std is not None:
         mask, psi = _psi_masked(
             positive_probs=pt.as_tensor(positive_probs),
@@ -171,12 +194,12 @@
 
 
 def _phi(
-    variables_importance,
-    variance_explained,
-    importance_concentration,
+    variables_importance: Union[pt.TensorLike, None],
+    variance_explained: Union[pt.TensorLike, None],
+    importance_concentration: Union[pt.TensorLike, None],
     *,
-    dims,
-):
+    dims: Sequence[str],
+) -> pt.TensorVariable:
     *broadcast_dims, dim = dims
     model = pm.modelcontext(None)
     if variables_importance is not None:
@@ -200,47 +223,50 @@
     return phi
 
 
+R2D2M2CPOut = namedtuple("R2D2M2CPOut", ["eps", "beta"])
namedtuple("R2D2M2CPOut", ["eps", "beta"]) + + def R2D2M2CP( - name, - output_sigma, - input_sigma, + name: str, + output_sigma: pt.TensorLike, + input_sigma: pt.TensorLike, *, - dims, - r2, - variables_importance=None, - variance_explained=None, - importance_concentration=None, - r2_std=None, - positive_probs=0.5, - positive_probs_std=None, - centered=False, -): + dims: Sequence[str], + r2: pt.TensorLike, + variables_importance: Union[pt.TensorLike, None] = None, + variance_explained: Union[pt.TensorLike, None] = None, + importance_concentration: Union[pt.TensorLike, None] = None, + r2_std: Union[pt.TensorLike, None] = None, + positive_probs: Union[pt.TensorLike, None] = 0.5, + positive_probs_std: Union[pt.TensorLike, None] = None, + centered: bool = False, +) -> R2D2M2CPOut: """R2D2M2CP Prior. Parameters ---------- name : str Name for the distribution - output_sigma : tensor + output_sigma : Tensor Output standard deviation - input_sigma : tensor + input_sigma : Tensor Input standard deviation dims : Union[str, Sequence[str]] Dims for the distribution - r2 : tensor + r2 : Tensor :math:`R^2` estimate - variables_importance : tensor, optional + variables_importance : Tensor, optional Optional estimate for variables importance, positive, by default None - variance_explained : tensor, optional + variance_explained : Tensor, optional Alternative estimate for variables importance which is point estimate of variance explained, should sum up to one, by default None - importance_concentration : tensor, optional + importance_concentration : Tensor, optional Confidence around variance explained or variable importance estimate - r2_std : tensor, optional + r2_std : Tensor, optional Optional uncertainty over :math:`R^2`, by default None - positive_probs : tensor, optional + positive_probs : Tensor, optional Optional probability of variables contribution to be positive, by default 0.5 - positive_probs_std : tensor, optional + positive_probs_std : Tensor, optional Optional uncertainty over effect direction probability, by default None centered : bool, optional Centered or Non-Centered parametrization of the distribution, by default Non-Centered. 
         Advised to check both
@@ -419,4 +445,4 @@ def R2D2M2CP(
         psi_mask=mask,
     )
     resid_sigma = (1 - r2) ** 0.5 * output_sigma
-    return resid_sigma, beta
+    return R2D2M2CPOut(resid_sigma, beta)
diff --git a/pymc_experimental/tests/distributions/test_multivariate.py b/pymc_experimental/tests/distributions/test_multivariate.py
index 012dc3cc..fcf5cfec 100644
--- a/pymc_experimental/tests/distributions/test_multivariate.py
+++ b/pymc_experimental/tests/distributions/test_multivariate.py
@@ -1,11 +1,17 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytest
 
 import pymc_experimental as pmx
 
 
 class TestR2D2M2CP:
+    @pytest.fixture(autouse=True)
+    def fast_compile(self):
+        with pytensor.config.change_flags(mode="FAST_COMPILE", exception_verbosity="high"):
+            yield
+
     @pytest.fixture(autouse=True)
     def model(self):
         # every method is within a model
@@ -95,17 +101,13 @@ def phi_args(self, request, phi_args_base):
             phi_args_base["importance_concentration"] = 10
         return phi_args_base
 
-    def test_init(
+    def test_init_r2(
         self,
         dims,
-        centered,
         input_std,
         output_std,
         r2,
         r2_std,
-        positive_probs,
-        positive_probs_std,
-        phi_args,
         model: pm.Model,
     ):
         eps, beta = pmx.distributions.R2D2M2CP(
@@ -115,10 +117,6 @@
             dims=dims,
             r2=r2,
             r2_std=r2_std,
-            centered=centered,
-            positive_probs_std=positive_probs_std,
-            positive_probs=positive_probs,
-            **phi_args
         )
         assert not np.isnan(beta.eval()).any()
         assert eps.eval().shape == output_std.shape
@@ -127,9 +125,63 @@
         assert "beta" in model.named_vars
         assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
         # phi is only created if variable importance is not None and there is more than one var
+        assert np.isfinite(model.compile_logp()(model.initial_point()))
+
+    def test_init_importance(
+        self,
+        dims,
+        centered,
+        input_std,
+        output_std,
+        phi_args,
+        model: pm.Model,
+    ):
+        eps, beta = pmx.distributions.R2D2M2CP(
+            "beta",
+            output_std,
+            input_std,
+            dims=dims,
+            r2=1,
+            centered=centered,
+            **phi_args,
+        )
+        assert not np.isnan(beta.eval()).any()
+        assert eps.eval().shape == output_std.shape
+        assert beta.eval().shape == input_std.shape
+        # r2 rv is only created if r2 std is not None
+        assert "beta" in model.named_vars
+        # phi is only created if variable importance is not None and there is more than one var
         assert ("beta::phi" in model.named_vars) == (
             "variables_importance" in phi_args or "importance_concentration" in phi_args
         ), set(model.named_vars)
+        assert np.isfinite(model.compile_logp()(model.initial_point()))
+
+    def test_init_positive_probs(
+        self,
+        dims,
+        centered,
+        input_std,
+        output_std,
+        positive_probs,
+        positive_probs_std,
+        model: pm.Model,
+    ):
+        eps, beta = pmx.distributions.R2D2M2CP(
+            "beta",
+            output_std,
+            input_std,
+            dims=dims,
+            r2=1.0,
+            centered=centered,
+            positive_probs_std=positive_probs_std,
+            positive_probs=positive_probs,
+        )
+        assert not np.isnan(beta.eval()).any()
+        assert eps.eval().shape == output_std.shape
+        assert beta.eval().shape == input_std.shape
+        # r2 rv is only created if r2 std is not None
+        assert "beta" in model.named_vars
+        # phi is only created if variable importance is not None and there is more than one var
         assert ("beta::psi" in model.named_vars) == (
             positive_probs_std is not None and positive_probs_std.any()
         ), set(model.named_vars)
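
Usage note (reviewer sketch, not part of the patch): with the new R2D2M2CPOut namedtuple the
result of R2D2M2CP still unpacks as a tuple, and the residual sigma and coefficients can now
also be reached by field name. A minimal example in the spirit of the tests above; the data,
the coordinate name "variables", and the "obs" variable are made up for illustration:

    import numpy as np
    import pymc as pm
    import pymc_experimental as pmx

    # toy regression data (illustrative only)
    X = np.random.randn(100, 3)
    y = X @ np.array([1.0, -0.5, 0.0]) + 0.1 * np.random.randn(100)

    with pm.Model(coords={"variables": ["a", "b", "c"]}) as model:
        out = pmx.distributions.R2D2M2CP(
            "beta",
            y.std(),      # output standard deviation
            X.std(0),     # per-variable input standard deviations
            dims="variables",
            r2=0.8,
            r2_std=0.05,
        )
        eps, beta = out  # tuple unpacking keeps working
        # named access via the R2D2M2CPOut fields
        pm.Normal("obs", pm.math.dot(X, beta), out.eps, observed=y)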