+from functools import singledispatch
+
+import aesara
+import aesara.tensor as at
+import numpy as np
+
+from aeppl.abstract import MeasurableVariable
+from aeppl.logprob import _logcdf, _logprob, icdf, logcdf
+from aesara import scan
+from aesara.graph import Op
+from aesara.graph.basic import Node
+from aesara.raise_op import CheckAndRaise
+from aesara.scan import until
+from aesara.tensor import TensorConstant, TensorVariable
+from aesara.tensor.random.basic import NormalRV
+from aesara.tensor.random.op import RandomVariable
+
+from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
+from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _moment,
+    moment,
+)
+from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
+from pymc.distributions.transforms import _default_transform
+from pymc.exceptions import TruncationError
+from pymc.math import logdiffexp
+from pymc.util import check_dist_not_registered
+
+
+class TruncatedRV(SymbolicRandomVariable):
+    """An `Op` constructed from an Aesara graph that represents a truncated univariate
+    random variable."""
+
+    default_output = 1
+    base_rv_op = None
+    max_n_steps = None
+
+    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+        self.base_rv_op = base_rv_op
+        self.max_n_steps = max_n_steps
+        super().__init__(*args, **kwargs)
+
+    def update(self, node: Node):
+        """Return the update mapping for the noise RV."""
+        # Since the RNG is a shared variable, it shows up as the last node input
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+MeasurableVariable.register(TruncatedRV)
+
+
+@singledispatch
+def _truncated(op: Op, lower, upper, *params):
+    """Return the truncated equivalent of another `RandomVariable`."""
+    raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
+
+
+class TruncationCheck(CheckAndRaise):
+    """Implements a check in truncated graphs.
+
+    Raises `TruncationError` if the check is not True.
+    """
+
+    def __init__(self, msg=""):
+        super().__init__(TruncationError, msg)
+
+    def __str__(self):
+        return f"TruncationCheck{{{self.msg}}}"
+
+
+class Truncated(Distribution):
+    r"""
+    Truncated distribution
+
+    The pdf of a truncated distribution is
+
+    .. math::
+
+        \begin{cases}
+            0 & \text{for } x < lower, \\
+            \frac{\text{PDF}(x, dist)}{\text{CDF}(upper, dist) - \text{CDF}(lower, dist)}
+            & \text{for } lower \leq x \leq upper, \\
+            0 & \text{for } x > upper,
+        \end{cases}
+
+    Parameters
+    ----------
+    dist: unnamed distribution
+        Univariate distribution created via the `.dist()` API, which will be truncated.
+        This distribution must be a pure RandomVariable and have a logcdf method
+        implemented for MCMC sampling.
+
+        .. warning:: dist will be cloned, rendering it independent of the one passed as input.
+
+    lower: tensor_like of float or None
+        Lower (left) truncation point. If `None`, the distribution will not be left truncated.
+    upper: tensor_like of float or None
+        Upper (right) truncation point. If `None`, the distribution will not be right truncated.
+    max_n_steps: int, defaults to 10_000
+        Maximum number of resamples that are attempted when performing rejection sampling.
+        A `TruncationError` is raised if convergence is not reached after that many steps.
+
+    Returns
+    -------
+    truncated_distribution: TensorVariable
+        Graph representing a truncated `RandomVariable`. A specialized `Op` may be used
+        if the `Op` of the dist has a dispatched `_truncated` function. Otherwise, a
+        `SymbolicRandomVariable` graph representing the truncation process is returned,
+        via inverse CDF sampling (if the underlying dist has a logcdf method) or via
+        rejection sampling.
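+
+    Examples
+    --------
+    A minimal usage sketch (the distribution, bounds, and names are illustrative):
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model():
+            normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
+            truncated_normal = pm.Truncated(
+                "truncated_normal", normal_dist, lower=-1.0, upper=1.0
+            )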
+    """
+
+    rv_type = TruncatedRV
+
+    @classmethod
+    def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
+        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+            if isinstance(dist.owner.op, SymbolicRandomVariable):
+                raise NotImplementedError(
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                )
+            raise ValueError(
+                f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
+            )
+
+        if dist.owner.op.ndim_supp > 0:
+            raise NotImplementedError("Truncation not implemented for multivariate distributions")
+
+        check_dist_not_registered(dist)
+
+        if lower is None and upper is None:
+            raise ValueError("lower and upper cannot both be None")
+
+        return super().dist([dist, lower, upper, max_n_steps], **kwargs)
+
+    @classmethod
+    def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
+
+        # Try to use specialized Op
+        try:
+            return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs)
+        except NotImplementedError:
+            pass
+
+        lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
+        upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
+
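+        # When no size is given, infer it by broadcasting the base dist with the bounds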
+        if size is None:
+            size = at.broadcast_shape(dist, lower, upper)
+        dist = change_dist_size(dist, new_size=size)
+
+        # Variables with `_` suffix identify dummy inputs for the OpFromGraph
+        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
+        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        *rv_inputs_, lower_, upper_ = graph_inputs_
+
+        # We will use a Shared RNG variable because Scan demands it, even though it
+        # would not be necessary for the OpFromGraph inverse cdf.
+        rng = aesara.shared(np.random.default_rng())
+        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+
+        # Try to use inverted cdf sampling
+        try:
+            # For left truncated discrete RVs, we need to include the whole lower bound.
+            # This may result in draws below the truncation range, if any uniform == 0
+            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
+            cdf_lower_ = at.exp(logcdf(rv_, lower_value))
+            cdf_upper_ = at.exp(logcdf(rv_, upper_))
+            # It's okay to reuse the same rng here, because the rng in rv_ will not be
+            # used by either the logcdf or icdf functions
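+            # rv_inputs_ excludes the rng, so its first entry is the base RV's size input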
+            uniform_ = at.random.uniform(
+                cdf_lower_,
+                cdf_upper_,
+                rng=rng,
+                size=rv_inputs_[0],
+            )
+            truncated_rv_ = icdf(rv_, uniform_)
+            return TruncatedRV(
+                base_rv_op=dist.owner.op,
+                inputs=graph_inputs_,
+                outputs=[uniform_.owner.outputs[0], truncated_rv_],
+                ndim_supp=0,
+                max_n_steps=max_n_steps,
+            )(*graph_inputs)
+        except NotImplementedError:
+            pass
+
+        # Fallback to rejection sampling
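+        # Each iteration redraws only the still-rejected entries and stops early
+        # once every draw falls inside [lower, upper]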
+        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
+            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+            truncated_rv = at.set_subtensor(
+                truncated_rv[reject_draws],
+                new_truncated_rv[reject_draws],
+            )
+            reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
+
+            return (
+                (truncated_rv, reject_draws),
+                [(rng, next_rng)],
+                until(~at.any(reject_draws)),
+            )
+
+        (truncated_rv_, reject_draws_), updates = scan(
+            loop_fn,
+            outputs_info=[
+                at.zeros_like(rv_),
+                at.ones_like(rv_, dtype=bool),
+            ],
+            non_sequences=[lower_, upper_, rng, *rv_inputs_],
+            n_steps=max_n_steps,
+            strict=True,
+        )
+
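+        # Keep only the final state of the scan: the last draws and rejection flags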
+        truncated_rv_ = truncated_rv_[-1]
+        convergence_ = ~at.any(reject_draws_[-1])
+        truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
+            truncated_rv_, convergence_
+        )
+
+        return TruncatedRV(
+            base_rv_op=dist.owner.op,
+            inputs=graph_inputs_,
+            outputs=[tuple(updates.values())[0], truncated_rv_],
+            ndim_supp=0,
+            max_n_steps=max_n_steps,
+        )(*graph_inputs)
+
+
+@_change_dist_size.register(TruncatedRV)
+def change_truncated_size(op, dist, new_size, expand):
+    *rv_inputs, lower, upper, rng = dist.owner.inputs
+    # Recreate the original untruncated RV
+    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    if expand:
+        new_size = to_tuple(new_size) + tuple(dist.shape)
+
+    return Truncated.rv_op(
+        untruncated_rv,
+        lower=lower,
+        upper=upper,
+        size=new_size,
+        max_n_steps=op.max_n_steps,
+    )
+
+
+@_moment.register(TruncatedRV)
+def truncated_moment(op, rv, *inputs):
+    *rv_inputs, lower, upper, rng = inputs
+
+    # Recreate the untruncated rv and its moment
+    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_moment = moment(untruncated_rv)
+
+    fallback_moment = at.switch(
+        at.and_(at.bitwise_not(at.isinf(lower)), at.bitwise_not(at.isinf(upper))),
+        (lower + upper) / 2,  # lower and upper are finite: use the midpoint
+        at.switch(
+            at.isinf(upper),
+            lower + 1,  # only lower is finite
+            upper - 1,  # only upper is finite
+        ),
+    )
+
+    return at.switch(
+        at.and_(at.ge(untruncated_moment, lower), at.le(untruncated_moment, upper)),
+        untruncated_moment,  # untruncated moment is between lower and upper
+        fallback_moment,
+    )
+
+
+@_default_transform.register(TruncatedRV)
+def truncated_default_transform(op, rv):
+    # Don't transform discrete truncated distributions
+    if op.base_rv_op.dtype.startswith("int"):
+        return None
+    # Lower and upper are the arguments at positions -3 and -2
+    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+
+
+@_logprob.register(TruncatedRV)
+def truncated_logprob(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    *rv_inputs, lower, upper, rng = inputs
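+    # Restore the base RV's input order, where the rng comes first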
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
+    # For discrete left truncated RVs, the lower bound belongs to the support, so the
+    # normalization term evaluates the CDF just below it (at lower - 1)
+    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
+    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logp.name = f"{base_rv_op}_logprob"
+        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
+        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
+
+    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
+    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
+
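+    # Normalization constant: log(CDF(upper) - CDF(lower)); when a bound is infinite,
+    # its CDF term is 1 or 0 and the expression simplifies accordingly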
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = at.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logp = logp - lognorm
+
+    if is_lower_bounded:
+        logp = at.switch(value < lower, -np.inf, logp)
+
+    if is_upper_bounded:
+        logp = at.switch(value <= upper, logp, -np.inf)
+
+    if is_lower_bounded and is_upper_bounded:
+        logp = check_parameters(
+            logp,
+            at.le(lower, upper),
+            msg="lower_bound <= upper_bound",
+        )
+
+    return logp
+
+
+@_truncated.register(NormalRV)
+def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma):
+    return TruncatedNormal.dist(
+        mu=mu,
+        sigma=sigma,
+        lower=lower,
+        upper=upper,
+        rng=None,  # Do not reuse rng to avoid weird dependencies
+        size=size,
+        dtype=dtype,
+    )