diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 31f22eac2a..48256afc77 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -49,20 +49,21 @@ jobs:
           - |
             pymc/tests/distributions/test_distribution.py
-            pymc/tests/distributions/test_bound.py
-            pymc/tests/distributions/test_censored.py
             pymc/tests/distributions/test_discrete.py
             pymc/tests/distributions/test_continuous.py
             pymc/tests/distributions/test_multivariate.py
+
+          - |
+            pymc/tests/distributions/test_bound.py
+            pymc/tests/distributions/test_censored.py
             pymc/tests/distributions/test_simulator.py
+            pymc/tests/distributions/test_truncated.py

           - |
             pymc/tests/tuning/test_scaling.py
             pymc/tests/tuning/test_starting.py
             pymc/tests/test_shared.py
             pymc/tests/test_types.py
-
-          - |
             pymc/tests/distributions/test_dist_math.py
             pymc/tests/distributions/test_transform.py
             pymc/tests/test_parallel_sampling.py
diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index 3240bde379..6887321c78 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -110,6 +110,7 @@
     MvStudentTRandomWalk,
     RandomWalk,
 )
+from pymc.distributions.truncated import Truncated

 __all__ = [
     "Uniform",
@@ -192,6 +193,7 @@
     "Rice",
     "Moyal",
     "Simulator",
+    "Truncated",
     "Censored",
     "CAR",
     "PolyaGamma",
diff --git a/pymc/distributions/bound.py b/pymc/distributions/bound.py
index 3df0ff7fda..c23d555f72 100644
--- a/pymc/distributions/bound.py
+++ b/pymc/distributions/bound.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings

 import aesara.tensor as at
 import numpy as np
@@ -182,6 +183,14 @@ def __new__(
         **kwargs,
     ):

+        warnings.warn(
+            "Bound has been deprecated in favor of Truncated, and will be removed in a "
+            "future release. If Truncated is not an option, Bound can be implemented by "
+            "adding an IntervalTransform between lower and upper to a continuous "
+            "variable. A Potential that returns negative infinity for values outside "
+            "of the bounds can be used for discrete variables.",
+            FutureWarning,
+        )
         cls._argument_checks(dist, **kwargs)

         if dims is not None:
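The two migration paths named in the deprecation message can be sketched as follows. This is a minimal, hypothetical model: `pm.Truncated` is the API introduced by this diff, while `Interval` and `pm.Potential` are existing PyMC APIs; the `initval` is only needed because the transform alone does not move the default starting point inside the bounds.

```python
import numpy as np
import pymc as pm
from pymc.distributions.transforms import Interval

with pm.Model():
    # Preferred replacement: Truncated renormalizes the density to the bounds
    x = pm.Truncated("x", pm.Normal.dist(0, 1), lower=1, upper=10)

    # Bound-like behavior for a continuous variable: constrain the support with
    # an interval transform (no renormalization, matching the old Bound)
    y = pm.Normal("y", 0, 1, transform=Interval(1, 10), initval=5)

    # Bound-like behavior for a discrete variable: a Potential that contributes
    # -inf whenever the value falls outside the bounds
    z = pm.Poisson("z", mu=4)
    pm.Potential("z_bounds", pm.math.switch((z >= 1) & (z <= 10), 0, -np.inf))
```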
diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py
new file mode 100644
index 0000000000..c0359ec64e
--- /dev/null
+++ b/pymc/distributions/truncated.py
@@ -0,0 +1,342 @@
+from functools import singledispatch
+
+import aesara
+import aesara.tensor as at
+import numpy as np
+
+from aeppl.abstract import MeasurableVariable
+from aeppl.logprob import _logcdf, _logprob, icdf, logcdf
+from aesara import scan
+from aesara.graph import Op
+from aesara.graph.basic import Node
+from aesara.raise_op import CheckAndRaise
+from aesara.scan import until
+from aesara.tensor import TensorConstant, TensorVariable
+from aesara.tensor.random.basic import NormalRV
+from aesara.tensor.random.op import RandomVariable
+
+from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
+from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _moment,
+    moment,
+)
+from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
+from pymc.distributions.transforms import _default_transform
+from pymc.exceptions import TruncationError
+from pymc.math import logdiffexp
+from pymc.util import check_dist_not_registered
+
+
+class TruncatedRV(SymbolicRandomVariable):
+    """An `Op` constructed from an Aesara graph that represents a truncated univariate
+    random variable."""
+
+    default_output = 1
+    base_rv_op = None
+    max_n_steps = None
+
+    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+        self.base_rv_op = base_rv_op
+        self.max_n_steps = max_n_steps
+        super().__init__(*args, **kwargs)
+
+    def update(self, node: Node):
+        """Return the update mapping for the noise RV."""
+        # Since the RNG is a shared variable it shows up as the last node input
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+MeasurableVariable.register(TruncatedRV)
+
+
+@singledispatch
+def _truncated(op: Op, lower, upper, *params):
+    """Return the truncated equivalent of another `RandomVariable`."""
+    raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
+
+
+class TruncationCheck(CheckAndRaise):
+    """Implements a check in truncated graphs.
+    Raises `TruncationError` if the check is not True.
+    """
+
+    def __init__(self, msg=""):
+        super().__init__(TruncationError, msg)
+
+    def __str__(self):
+        return f"TruncationCheck{{{self.msg}}}"
+
+
+class Truncated(Distribution):
+    r"""
+    Truncated distribution
+
+    The pdf of a truncated distribution is
+
+    .. math::
+
+        \begin{cases}
+            0 & \text{for } x < lower, \\
+            \frac{\text{PDF}(x, dist)}{\text{CDF}(upper, dist) - \text{CDF}(lower, dist)}
+            & \text{for } lower <= x <= upper, \\
+            0 & \text{for } x > upper,
+        \end{cases}
+
+
+    Parameters
+    ----------
+    dist: unnamed distribution
+        Univariate distribution created via the `.dist()` API, which will be truncated.
+        This distribution must be a pure RandomVariable and have a logcdf method
+        implemented for MCMC sampling.
+
+        .. warning:: dist will be cloned, rendering it independent of the one passed as input.
+
+    lower: tensor_like of float or None
+        Lower (left) truncation point. If `None` the distribution will not be left truncated.
+    upper: tensor_like of float or None
+        Upper (right) truncation point. If `None`, the distribution will not be right truncated.
+    max_n_steps: int, defaults to 10_000
+        Maximum number of resamples that are attempted when performing rejection sampling.
+        A `TruncationError` is raised if convergence is not reached after that many steps.
+
+    Returns
+    -------
+    truncated_distribution: TensorVariable
+        Graph representing a truncated `RandomVariable`. A specialized `Op` may be used
+        if the `Op` of the dist has a dispatched `_truncated` function. Otherwise, a
+        `SymbolicRandomVariable` graph representing the truncation process, via inverse
+        CDF sampling (if the underlying dist has logcdf and icdf implementations), or
+        rejection sampling, is returned.
+    """
+
+    rv_type = TruncatedRV
+
+    @classmethod
+    def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
+        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+            if isinstance(dist.owner.op, SymbolicRandomVariable):
+                raise NotImplementedError(
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                )
+            raise ValueError(
+                f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
+            )
+
+        if dist.owner.op.ndim_supp > 0:
+            raise NotImplementedError("Truncation not implemented for multivariate distributions")
+
+        check_dist_not_registered(dist)
+
+        if lower is None and upper is None:
+            raise ValueError("lower and upper cannot both be None")
+
+        return super().dist([dist, lower, upper, max_n_steps], **kwargs)
+
+    @classmethod
+    def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
+
+        # Try to use a specialized Op
+        try:
+            return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs)
+        except NotImplementedError:
+            pass
+
+        lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
+        upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
+
+        if size is None:
+            size = at.broadcast_shape(dist, lower, upper)
+        dist = change_dist_size(dist, new_size=size)
+
+        # Variables with `_` suffix identify dummy inputs for the OpFromGraph
+        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
+        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        *rv_inputs_, lower_, upper_ = graph_inputs_
+
+        # We will use a shared RNG variable because Scan demands it, even though it
+        # would not be necessary for the OpFromGraph inverse cdf.
+        rng = aesara.shared(np.random.default_rng())
+        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+
+        # Try to use inverse cdf sampling
+        try:
+            # For left truncated discrete RVs, we need to include the whole lower bound.
+            # This may result in draws below the truncation range, if any uniform == 0
+            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
+            cdf_lower_ = at.exp(logcdf(rv_, lower_value))
+            cdf_upper_ = at.exp(logcdf(rv_, upper_))
+            # It's okay to reuse the same rng here, because the rng in rv_ will not be
+            # used by either the logcdf or icdf functions
+            uniform_ = at.random.uniform(
+                cdf_lower_,
+                cdf_upper_,
+                rng=rng,
+                size=rv_inputs_[0],
+            )
+            truncated_rv_ = icdf(rv_, uniform_)
+            return TruncatedRV(
+                base_rv_op=dist.owner.op,
+                inputs=graph_inputs_,
+                outputs=[uniform_.owner.outputs[0], truncated_rv_],
+                ndim_supp=0,
+                max_n_steps=max_n_steps,
+            )(*graph_inputs)
+        except NotImplementedError:
+            pass
+
+        # Fallback to rejection sampling
+        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
+            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+            truncated_rv = at.set_subtensor(
+                truncated_rv[reject_draws],
+                new_truncated_rv[reject_draws],
+            )
+            reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
+
+            return (
+                (truncated_rv, reject_draws),
+                [(rng, next_rng)],
+                until(~at.any(reject_draws)),
+            )
+
+        (truncated_rv_, reject_draws_), updates = scan(
+            loop_fn,
+            outputs_info=[
+                at.zeros_like(rv_),
+                at.ones_like(rv_, dtype=bool),
+            ],
+            non_sequences=[lower_, upper_, rng, *rv_inputs_],
+            n_steps=max_n_steps,
+            strict=True,
+        )
+
+        truncated_rv_ = truncated_rv_[-1]
+        convergence_ = ~at.any(reject_draws_[-1])
+        truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
+            truncated_rv_, convergence_
+        )
+
+        return TruncatedRV(
+            base_rv_op=dist.owner.op,
+            inputs=graph_inputs_,
+            outputs=[tuple(updates.values())[0], truncated_rv_],
+            ndim_supp=0,
+            max_n_steps=max_n_steps,
+        )(*graph_inputs)
+
+
+@_change_dist_size.register(TruncatedRV)
+def change_truncated_size(op, dist, new_size, expand):
+    *rv_inputs, lower, upper, rng = dist.owner.inputs
+    # Recreate the original untruncated RV
+    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    if expand:
+        new_size = to_tuple(new_size) + tuple(dist.shape)
+
+    return Truncated.rv_op(
+        untruncated_rv,
+        lower=lower,
+        upper=upper,
+        size=new_size,
+        max_n_steps=op.max_n_steps,
+    )
+
+
+@_moment.register(TruncatedRV)
+def truncated_moment(op, rv, *inputs):
+    *rv_inputs, lower, upper, rng = inputs
+
+    # Recreate the untruncated RV and its moment
+    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_moment = moment(untruncated_rv)
+
+    fallback_moment = at.switch(
+        at.and_(at.bitwise_not(at.isinf(lower)), at.bitwise_not(at.isinf(upper))),
+        (upper - lower) / 2,  # lower and upper are finite
+        at.switch(
+            at.isinf(upper),
+            lower + 1,  # only lower is finite
+            upper - 1,  # only upper is finite
+        ),
+    )
+
+    return at.switch(
+        at.and_(at.ge(untruncated_moment, lower), at.le(untruncated_moment, upper)),
+        untruncated_moment,  # untruncated moment is between lower and upper
+        fallback_moment,
+    )
+
+
+@_default_transform.register(TruncatedRV)
+def truncated_default_transform(op, rv):
+    # Don't transform discrete truncated distributions
+    if op.base_rv_op.dtype.startswith("int"):
+        return None
+    # Lower and upper are the arguments -3 and -2
+    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+
+
+@_logprob.register(TruncatedRV)
+def truncated_logprob(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    *rv_inputs, lower, upper, rng = inputs
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
+    # For left truncated discrete RVs the lower bound is part of the support,
+    # so the normalization term must use the CDF just below it
+    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
+    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logp.name = f"{base_rv_op}_logprob"
+        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
+        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
+
+    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
+    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
+
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = at.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logp = logp - lognorm
+
+    if is_lower_bounded:
+        logp = at.switch(value < lower, -np.inf, logp)
+
+    if is_upper_bounded:
+        logp = at.switch(value <= upper, logp, -np.inf)
+
+    if is_lower_bounded and is_upper_bounded:
+        logp = check_parameters(
+            logp,
+            at.le(lower, upper),
+            msg="lower_bound <= upper_bound",
+        )
+
+    return logp
+
+
+@_truncated.register(NormalRV)
+def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma):
+    return TruncatedNormal.dist(
+        mu=mu,
+        sigma=sigma,
+        lower=lower,
+        upper=upper,
+        rng=None,  # Do not reuse rng to avoid weird dependencies
+        size=size,
+        dtype=dtype,
+    )
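For orientation, a minimal usage sketch of the API added above (hypothetical parameters; `pm.Gamma` stands in for any univariate base distribution with a `logcdf` — inverse-CDF sampling is used when an `icdf` is available, rejection sampling otherwise):

```python
import numpy as np
import pymc as pm

# Truncate a Gamma(2, 1) to the interval [1, 5]
xt = pm.Truncated.dist(pm.Gamma.dist(alpha=2, beta=1), lower=1, upper=5)

# Draws always land inside the truncation bounds
draws = pm.draw(xt, draws=1000)
assert np.all((draws >= 1) & (draws <= 5))

# The logp is renormalized by the mass inside [1, 5]; values outside get -inf
print(pm.logp(xt, 2.0).eval(), pm.logp(xt, 6.0).eval())
```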
diff --git a/pymc/exceptions.py b/pymc/exceptions.py
index 1d44bb8865..0d7ba3eaaf 100644
--- a/pymc/exceptions.py
+++ b/pymc/exceptions.py
@@ -74,3 +74,7 @@ def __init__(self, message, actual=None, expected=None):
             super().__init__(f"{message} (expected {expected})")
         else:
             super().__init__(message)
+
+
+class TruncationError(Exception):
+    """Exception for errors generated from truncated graphs."""
diff --git a/pymc/tests/distributions/test_bound.py b/pymc/tests/distributions/test_bound.py
index a1301740ee..57773b2577 100644
--- a/pymc/tests/distributions/test_bound.py
+++ b/pymc/tests/distributions/test_bound.py
@@ -31,20 +31,21 @@ class TestBound:
     def test_continuous(self):
         with pm.Model() as model:
             dist = pm.Normal.dist(mu=0, sigma=1)
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore", "invalid value encountered in add", RuntimeWarning
-                )
-                UnboundedNormal = pm.Bound("unbound", dist, transform=None)
-                InfBoundedNormal = pm.Bound(
-                    "infbound", dist, lower=-np.inf, upper=np.inf, transform=None
-                )
-                LowerNormal = pm.Bound("lower", dist, lower=0, transform=None)
-                UpperNormal = pm.Bound("upper", dist, upper=0, transform=None)
-                BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10, transform=None)
-                LowerNormalTransform = pm.Bound("lowertrans", dist, lower=1)
-                UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10)
-                BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore", "invalid value encountered in add", RuntimeWarning
+                    )
+                    UnboundedNormal = pm.Bound("unbound", dist, transform=None)
+                    InfBoundedNormal = pm.Bound(
+                        "infbound", dist, lower=-np.inf, upper=np.inf, transform=None
+                    )
+                    LowerNormal = pm.Bound("lower", dist, lower=0, transform=None)
+                    UpperNormal = pm.Bound("upper", dist, upper=0, transform=None)
+                    BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10, transform=None)
+                    LowerNormalTransform = pm.Bound("lowertrans", dist, lower=1)
+                    UpperNormalTransform = pm.Bound("uppertrans", dist, upper=10)
+                    BoundedNormalTransform = pm.Bound("boundedtrans", dist, lower=1, upper=10)

         assert joint_logp(LowerNormal, -1).eval() == -np.inf
         assert joint_logp(UpperNormal, 1).eval() == -np.inf
@@ -73,14 +74,15 @@ def test_continuous(self):
     def test_discrete(self):
         with pm.Model() as model:
             dist = pm.Poisson.dist(mu=4)
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore", "invalid value encountered in add", RuntimeWarning
-                )
-                UnboundedPoisson = pm.Bound("unbound", dist)
-                LowerPoisson = pm.Bound("lower", dist, lower=1)
-                UpperPoisson = pm.Bound("upper", dist, upper=10)
-                BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore", "invalid value encountered in add", RuntimeWarning
+                    )
+                    UnboundedPoisson = pm.Bound("unbound", dist)
+                    LowerPoisson = pm.Bound("lower", dist, lower=1)
+                    UpperPoisson = pm.Bound("upper", dist, upper=10)
+                    BoundedPoisson = pm.Bound("bounded", dist, lower=1, upper=10)

         assert joint_logp(LowerPoisson, 0).eval() == -np.inf
         assert joint_logp(UpperPoisson, 11).eval() == -np.inf
@@ -118,8 +120,9 @@ def test_arguments_checks(self):
         msg = "Observed Bound distributions are not supported"
         with pm.Model() as m:
             x = pm.Normal("x", 0, 1)
-            with pytest.raises(ValueError, match=msg):
-                pm.Bound("bound", x, observed=5)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(ValueError, match=msg):
+                    pm.Bound("bound", x, observed=5)

         msg = "Cannot transform discrete variable."
         with pm.Model() as m:
@@ -128,52 +131,60 @@ def test_arguments_checks(self):
             x = pm.Poisson.dist(0.5)
             with warnings.catch_warnings():
                 warnings.filterwarnings(
                     "ignore", "invalid value encountered in add", RuntimeWarning
                 )
-                with pytest.raises(ValueError, match=msg):
-                    pm.Bound("bound", x, transform=pm.distributions.transforms.log)
+                with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                    with pytest.raises(ValueError, match=msg):
+                        pm.Bound("bound", x, transform=pm.distributions.transforms.log)

         msg = "Given dims do not exist in model coordinates."
         with pm.Model() as m:
             x = pm.Poisson.dist(0.5)
-            with pytest.raises(ValueError, match=msg):
-                pm.Bound("bound", x, dims="random_dims")
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(ValueError, match=msg):
+                    pm.Bound("bound", x, dims="random_dims")

         msg = "The dist x was already registered in the current model"
         with pm.Model() as m:
             x = pm.Normal("x", 0, 1)
-            with pytest.raises(ValueError, match=msg):
-                pm.Bound("bound", x)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(ValueError, match=msg):
+                    pm.Bound("bound", x)

         msg = "Passing a distribution class to `Bound` is no longer supported"
         with pm.Model() as m:
-            with pytest.raises(ValueError, match=msg):
-                pm.Bound("bound", pm.Normal)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(ValueError, match=msg):
+                    pm.Bound("bound", pm.Normal)

         msg = "Bounding of MultiVariate RVs is not yet supported"
         with pm.Model() as m:
             x = pm.MvNormal.dist(np.zeros(3), np.eye(3))
-            with pytest.raises(NotImplementedError, match=msg):
-                pm.Bound("bound", x)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(NotImplementedError, match=msg):
+                    pm.Bound("bound", x)

         msg = "must be a Discrete or Continuous distribution subclass"
         with pm.Model() as m:
             x = self.create_invalid_distribution().dist()
-            with pytest.raises(ValueError, match=msg):
-                pm.Bound("bound", x)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with pytest.raises(ValueError, match=msg):
+                    pm.Bound("bound", x)

     def test_invalid_sampling(self):
         msg = "Cannot sample from a bounded variable"
         with pm.Model() as m:
             dist = pm.Normal.dist(mu=0, sigma=1)
-            BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10)
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                BoundedNormal = pm.Bound("bounded", dist, lower=1, upper=10)

         with pytest.raises(NotImplementedError, match=msg):
             pm.sample_prior_predictive()

     def test_bound_shapes(self):
         with pm.Model(coords={"sample": np.ones((2, 5))}) as m:
             dist = pm.Normal.dist(mu=0, sigma=1)
-            bound_sized = pm.Bound("boundedsized", dist, lower=1, upper=10, size=(4, 5))
-            bound_shaped = pm.Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
-            bound_dims = pm.Bound("boundeddims", dist, lower=1, upper=10, dims="sample")
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                bound_sized = pm.Bound("boundedsized", dist, lower=1, upper=10, size=(4, 5))
+                bound_shaped = pm.Bound("boundedshaped", dist, lower=1, upper=10, shape=(3, 5))
+                bound_dims = pm.Bound("boundeddims", dist, lower=1, upper=10, dims="sample")

         initial_point = m.initial_point()
         dist_size = initial_point["boundedsized_interval__"].shape
@@ -198,13 +209,16 @@ def test_bound_dist(self):
     def test_array_bound(self):
         with pm.Model() as model:
             dist = pm.Normal.dist()
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore", "invalid value encountered in add", RuntimeWarning
+            with pytest.warns(FutureWarning, match="Bound has been deprecated"):
+                with warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore", "invalid value encountered in add", RuntimeWarning
+                    )
+                    LowerPoisson = pm.Bound("lower", dist, lower=[1, None], transform=None)
+                    UpperPoisson = pm.Bound("upper", dist, upper=[np.inf, 10], transform=None)
+                    BoundedPoisson = pm.Bound(
+                        "bounded", dist, lower=[1, 2], upper=[9, 10], transform=None
                 )
-            LowerPoisson = pm.Bound("lower", dist, lower=[1, None], transform=None)
-            UpperPoisson = pm.Bound("upper", dist, upper=[np.inf, 10], transform=None)
-            BoundedPoisson = pm.Bound("bounded", dist, lower=[1, 2], upper=[9, 10], transform=None)

         first, second = joint_logp(LowerPoisson, [0, 0], sum=False)[0].eval()
         assert first == -np.inf
diff --git a/pymc/tests/distributions/test_truncated.py b/pymc/tests/distributions/test_truncated.py
new file mode 100644
index 0000000000..dc589d8255
--- /dev/null
+++ b/pymc/tests/distributions/test_truncated.py
@@ -0,0 +1,300 @@
+import aesara
+import aesara.tensor as at
+import numpy as np
+import pytest
+import scipy
+
+from aeppl.logprob import ParameterValueError, _icdf
+from aeppl.transforms import IntervalTransform
+from aesara.tensor.random.basic import GeometricRV, NormalRV
+
+from pymc import Censored, Model, draw, find_MAP, logp
+from pymc.distributions.continuous import Exponential, TruncatedNormalRV
+from pymc.distributions.shape_utils import change_dist_size
+from pymc.distributions.transforms import _default_transform
+from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
+from pymc.exceptions import TruncationError
+from pymc.tests.distributions.util import assert_moment_is_expected
+
+
+class IcdfNormalRV(NormalRV):
+    """Normal RV that has icdf but not truncated dispatching."""
+
+
+class RejectionNormalRV(NormalRV):
+    """Normal RV that has neither icdf nor truncated dispatching."""
+
+
+class IcdfGeometricRV(GeometricRV):
+    """Geometric RV that has icdf but not truncated dispatching."""
+
+
+class RejectionGeometricRV(GeometricRV):
+    """Geometric RV that has neither icdf nor truncated dispatching."""
+
+
+icdf_normal = no_moment_normal = IcdfNormalRV()
+rejection_normal = RejectionNormalRV()
+icdf_geometric = IcdfGeometricRV()
+rejection_geometric = RejectionGeometricRV()
+
+
+@_truncated.register(IcdfNormalRV)
+@_truncated.register(RejectionNormalRV)
+@_truncated.register(IcdfGeometricRV)
+@_truncated.register(RejectionGeometricRV)
+def _truncated_not_implemented(*args, **kwargs):
+    raise NotImplementedError()
+
+
+@_icdf.register(RejectionNormalRV)
+@_icdf.register(RejectionGeometricRV)
+def _icdf_not_implemented(*args, **kwargs):
+    raise NotImplementedError()
+
+
+def test_truncation_specialized_op():
+    rng = aesara.shared(np.random.default_rng())
+    x = at.random.normal(0, 10, rng=rng, name="x")
+
+    xt = Truncated.dist(x, lower=5, upper=15, shape=(100,))
+    assert isinstance(xt.owner.op, TruncatedNormalRV)
+
+    # Test that the RNG is not reused
+    assert xt.owner.inputs[0] is not rng
+
+    lower_upper = at.stack(xt.owner.inputs[5:])
+    assert np.all(lower_upper.eval() == [5, 15])
+
+
+@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_continuous_random(op_type, lower, upper):
+    loc = 0.15
+    scale = 10
+    normal_op = icdf_normal if op_type == "icdf" else rejection_normal
+    x = normal_op(loc, scale, name="x", size=100)
+
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+    assert xt.type == x.type
+
+    xt_draws = draw(xt, draws=5)
+    assert np.all(xt_draws >= lower)
+    assert np.all(xt_draws <= upper)
+    assert np.unique(xt_draws).size == xt_draws.size
+
+    # Compare with reference implementation
+    ref_xt = scipy.stats.truncnorm(
+        (lower - loc) / scale,
+        (upper - loc) / scale,
+        loc,
+        scale,
+    )
+    assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001
+
+    # Test max_n_steps
+    xt = Truncated.dist(x, lower=lower, upper=upper, max_n_steps=1)
+    if op_type == "icdf":
+        xt_draws = draw(xt)
+        assert np.all(xt_draws >= lower)
+        assert np.all(xt_draws <= upper)
+        assert np.unique(xt_draws).size == xt_draws.size
+    else:
+        with pytest.raises(TruncationError, match="^Truncation did not converge"):
+            draw(xt)
+
+
+@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_continuous_logp(op_type, lower, upper):
+    loc = 0.15
+    scale = 10
+    op = icdf_normal if op_type == "icdf" else rejection_normal
+
+    x = op(loc, scale, name="x")
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    xt_vv = xt.clone()
+    xt_logp_fn = aesara.function([xt_vv], logp(xt, xt_vv))
+
+    ref_xt = scipy.stats.truncnorm(
+        (lower - loc) / scale,
+        (upper - loc) / scale,
+        loc,
+        scale,
+    )
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
+
+
+@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_discrete_random(op_type, lower, upper):
+    p = 0.2
+    geometric_op = icdf_geometric if op_type == "icdf" else rejection_geometric
+
+    x = geometric_op(p, name="x", size=500)
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+    assert xt.type == x.type
+
+    xt_draws = draw(xt)
+    assert np.all(xt_draws >= lower)
+    assert np.all(xt_draws <= upper)
+    assert np.any(xt_draws == (max(1, lower)))
+    if upper != np.inf:
+        assert np.any(xt_draws == upper)
+
+    # Test max_n_steps
+    xt = Truncated.dist(x, lower=lower, upper=upper, max_n_steps=3)
+    if op_type == "icdf":
+        xt_draws = draw(xt)
+        assert np.all(xt_draws >= lower)
+        assert np.all(xt_draws <= upper)
+        assert np.any(xt_draws == (max(1, lower)))
+        if upper != np.inf:
+            assert np.any(xt_draws == upper)
+    else:
+        with pytest.raises(TruncationError, match="^Truncation did not converge"):
+            draw(xt)
+
+
+@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_discrete_logp(op_type, lower, upper):
+    p = 0.7
+    op = icdf_geometric if op_type == "icdf" else rejection_geometric
+
+    x = op(p, name="x")
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    xt_vv = xt.clone()
+    xt_logp_fn = aesara.function([xt_vv], logp(xt, xt_vv))
+
+    ref_xt = scipy.stats.geom(p)
+    log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1))
+
+    def ref_xt_logpmf(value):
+        if value < lower or value > upper:
+            return -np.inf
+        return ref_xt.logpmf(value) - log_norm
+
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logp_fn(test_xt_v), ref_xt_logpmf(test_xt_v))
+
+    # Check that the truncated logp integrates to 1
+    log_integral = scipy.special.logsumexp([xt_logp_fn(v) for v in range(min(upper + 1, 20))])
+    assert np.isclose(log_integral, 0.0, atol=1e-5)
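The `lower - 1` term in `log_norm` above mirrors `truncated_logprob`: for a discrete base distribution the lower bound belongs to the truncated support, so the normalizer must subtract the CDF just below it. A standalone numeric check (scipy only, values chosen to match this test):

```python
import numpy as np
import scipy.stats

p, lower, upper = 0.7, 2, 5
ref = scipy.stats.geom(p)

# The normalizer keeps the mass at the lower bound: P(lower <= X <= upper)
norm = ref.cdf(upper) - ref.cdf(lower - 1)
pmf = np.array([ref.pmf(v) for v in range(lower, upper + 1)]) / norm
assert np.isclose(pmf.sum(), 1.0)  # the renormalized pmf sums to one
```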
+def test_truncation_exceptions():
+    with pytest.raises(ValueError, match="lower and upper cannot both be None"):
+        Truncated.dist(at.random.normal())
+
+    # Truncation does not work with SymbolicRV inputs
+    with pytest.raises(
+        NotImplementedError,
+        match="Truncation not implemented for SymbolicRandomVariable CensoredRV",
+    ):
+        Truncated.dist(Censored.dist(at.random.normal(), lower=-1, upper=1), -1, 1)
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Truncation not implemented for multivariate distributions",
+    ):
+        Truncated.dist(at.random.dirichlet([1, 1, 1]), -1, 1)
+
+
+def test_truncation_logprob_bound_check():
+    x = at.random.normal(name="x")
+    xt = Truncated.dist(x, lower=5, upper=-5)
+    with pytest.raises(ParameterValueError):
+        logp(xt, 0).eval()
+
+
+def test_change_truncated_size():
+    x = Truncated.dist(icdf_normal(0, [1, 2, 3]), lower=-1, size=(2, 3))
+    assert x.eval().shape == (2, 3)
+
+    new_x = change_dist_size(x, (4, 3))
+    assert isinstance(new_x.owner.op, TruncatedRV)
+    assert new_x.eval().shape == (4, 3)
+
+    new_x = change_dist_size(x, (4, 3), expand=True)
+    assert isinstance(new_x.owner.op, TruncatedRV)
+    assert new_x.eval().shape == (4, 3, 2, 3)
+
+
+def test_truncated_default_transform():
+    base_dist = rejection_geometric(1)
+    x = Truncated.dist(base_dist, lower=None, upper=5)
+    assert _default_transform(x.owner.op, x) is None
+
+    base_dist = rejection_normal(0, 1)
+    x = Truncated.dist(base_dist, lower=None, upper=5)
+    assert isinstance(_default_transform(x.owner.op, x), IntervalTransform)
+
+
+def test_truncated_transform_logp():
+    with Model() as m:
+        base_dist = rejection_normal(0, 1)
+        x = Truncated("x", base_dist, lower=0, upper=None, transform=None)
+        y = Truncated("y", base_dist, lower=0, upper=None)
+        logp_eval = m.compile_logp(sum=False)({"x": -1, "y_interval__": -1})
+    assert logp_eval[0] == -np.inf
+    assert np.isfinite(logp_eval[1])
+
+
+@pytest.mark.parametrize(
+    "truncated_dist, lower, upper, shape, expected",
+    [
+        # Moment of the untruncated dist can be used
+        (icdf_normal(0, 1), -1, 2, None, 0),
+        # Moment of the untruncated dist cannot be used, both bounds are finite
+        (icdf_normal(3, 1), -1, 2, (2,), np.full((2,), 3 / 2)),
+        # Moment of the untruncated dist cannot be used, only the lower bound is finite
+        (icdf_normal(-3, 1), -1, None, (2, 3), np.full((2, 3), 0)),
+        # Moment of the untruncated dist can be used for the 1st and 3rd mus, upper bound is finite
+        (icdf_normal([0, 3, 3], 1), None, [2, 2, 4], (4, 3), np.full((4, 3), [0, 1, 3])),
+    ],
+)
+def test_truncated_moment(truncated_dist, lower, upper, shape, expected):
+    with Model() as model:
+        Truncated("x", dist=truncated_dist, lower=lower, upper=upper, shape=shape)
+    assert_moment_is_expected(model, expected)
+
+
+def test_truncated_inference():
+    # Exercise 3.3, p. 47, from David MacKay's Information Theory book
+    lam_true = 3
+    lower = 0
+    upper = 5
+
+    rng = np.random.default_rng(260)
+    x = rng.exponential(lam_true, size=5000)
+    obs = x[np.where(~((x < lower) | (x > upper)))]  # remove values outside the range
+
+    with Model() as m:
+        lam = Exponential("lam", lam=1 / 5)  # prior exponential with a mean of 5
+        Truncated(
+            "x",
+            dist=Exponential.dist(lam=1 / lam),
+            lower=lower,
+            upper=upper,
+            observed=obs,
+        )
+
+        map = find_MAP(progressbar=False)
+
+    assert np.isclose(map["lam"], lam_true, atol=0.1)
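For completeness, the `_truncated` singledispatch that `Truncated.rv_op` tries first can be extended to other base distributions, mirroring `_truncated_normal` above. A hypothetical sketch for a uniform base — the function name and the assumption that `Uniform.dist` forwards `rng`/`size`/`dtype` (as `TruncatedNormal.dist` does in this diff) are illustrative, not part of the patch:

```python
import aesara.tensor as at
import numpy as np

from aesara.tensor.random.basic import UniformRV

from pymc.distributions.continuous import Uniform
from pymc.distributions.truncated import _truncated


@_truncated.register(UniformRV)
def _truncated_uniform(op, lower, upper, rng, size, dtype, low, high):
    # A Uniform truncated to [lower, upper] is itself a Uniform over the
    # clipped interval. Bounds may arrive as None (rv_op dispatches before
    # converting them), so map None to +/- inf first.
    lower = lower if lower is not None else -np.inf
    upper = upper if upper is not None else np.inf
    return Uniform.dist(
        lower=at.maximum(low, lower),
        upper=at.minimum(high, upper),
        rng=None,  # do not reuse the base rng, as in _truncated_normal
        size=size,
        dtype=dtype,
    )
```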