
Commit a1f9d00

Implement Truncated distributions
1 parent ec27b5c commit a1f9d00

File tree

5 files changed (+649, −0 lines)

.github/workflows/tests.yml (+1)

```diff
@@ -55,6 +55,7 @@ jobs:
             pymc/tests/distributions/test_continuous.py
             pymc/tests/distributions/test_multivariate.py
             pymc/tests/distributions/test_simulator.py
+            pymc/tests/distributions/test_truncated.py

           - |
             pymc/tests/tuning/test_scaling.py
```

pymc/distributions/__init__.py (+2)

```diff
@@ -110,6 +110,7 @@
     MvStudentTRandomWalk,
     RandomWalk,
 )
+from pymc.distributions.truncated import Truncated

 __all__ = [
     "Uniform",
@@ -192,6 +193,7 @@
     "Rice",
     "Moyal",
     "Simulator",
+    "Truncated",
     "Censored",
     "CAR",
     "PolyaGamma",
```

pymc/distributions/truncated.py (new file, +342 lines)
```python
from functools import singledispatch

import aesara
import aesara.tensor as at
import numpy as np

from aeppl.abstract import MeasurableVariable
from aeppl.logprob import _logcdf, _logprob, icdf, logcdf
from aesara import scan
from aesara.graph import Op
from aesara.graph.basic import Node
from aesara.raise_op import CheckAndRaise
from aesara.scan import until
from aesara.tensor import TensorConstant, TensorVariable
from aesara.tensor.random.basic import NormalRV
from aesara.tensor.random.op import RandomVariable

from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
    Distribution,
    SymbolicRandomVariable,
    _moment,
    moment,
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.math import logdiffexp
from pymc.util import check_dist_not_registered


class TruncatedRV(SymbolicRandomVariable):
    """An `Op` constructed from an Aesara graph that represents a truncated
    univariate random variable."""

    default_output = 1
    base_rv_op = None
    max_n_steps = None

    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
        self.base_rv_op = base_rv_op
        self.max_n_steps = max_n_steps
        super().__init__(*args, **kwargs)

    def update(self, node: Node):
        """Return the update mapping for the noise RV."""
        # Since the RNG is a shared variable, it shows up as the last node input
        return {node.inputs[-1]: node.outputs[0]}


MeasurableVariable.register(TruncatedRV)


@singledispatch
def _truncated(op: Op, lower, upper, *params):
    """Return the truncated equivalent of another `RandomVariable`."""
    raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")


class TruncationCheck(CheckAndRaise):
    """Implements a check in truncated graphs.

    Raises `TruncationError` if the check is not True.
    """

    def __init__(self, msg=""):
        super().__init__(TruncationError, msg)

    def __str__(self):
        return f"TruncationCheck{{{self.msg}}}"


class Truncated(Distribution):
    r"""
    Truncated distribution

    The pdf of a truncated distribution is

    .. math::

        \begin{cases}
            0 & \text{for } x < lower, \\
            \frac{\text{PDF}(x, dist)}{\text{CDF}(upper, dist) - \text{CDF}(lower, dist)}
            & \text{for } lower \leq x \leq upper, \\
            0 & \text{for } x > upper,
        \end{cases}

    Parameters
    ----------
    dist: unnamed distribution
        Univariate distribution created via the `.dist()` API, which will be truncated.
        This distribution must be a pure RandomVariable and have a logcdf method
        implemented for MCMC sampling.

        .. warning:: dist will be cloned, rendering it independent of the one passed as input.

    lower: tensor_like of float or None
        Lower (left) truncation point. If `None`, the distribution will not be left truncated.
    upper: tensor_like of float or None
        Upper (right) truncation point. If `None`, the distribution will not be right truncated.
    max_n_steps: int, defaults to 10_000
        Maximum number of resamples that are attempted when performing rejection sampling.
        A `TruncationError` is raised if convergence is not reached after that many steps.

    Returns
    -------
    truncated_distribution: TensorVariable
        Graph representing a truncated `RandomVariable`. A specialized `Op` is used
        if the `Op` of the dist has a dispatched `_truncated` function. Otherwise, a
        `SymbolicRandomVariable` graph is returned that represents the truncation
        process via inverse CDF sampling (if the underlying dist has a logcdf method)
        or rejection sampling.
    """

    rv_type = TruncatedRV

    @classmethod
    def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
            if isinstance(dist.owner.op, SymbolicRandomVariable):
                raise NotImplementedError(
                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
                )
            raise ValueError(
                f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )

        if dist.owner.op.ndim_supp > 0:
            raise NotImplementedError("Truncation not implemented for multivariate distributions")

        check_dist_not_registered(dist)

        if lower is None and upper is None:
            raise ValueError("lower and upper cannot both be None")

        return super().dist([dist, lower, upper, max_n_steps], **kwargs)

    @classmethod
    def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
        # Try to use a specialized Op
        try:
            return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs)
        except NotImplementedError:
            pass

        lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
        upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

        if size is None:
            size = at.broadcast_shape(dist, lower, upper)
        dist = change_dist_size(dist, new_size=size)

        # Variables with `_` suffix identify dummy inputs for the OpFromGraph
        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
        graph_inputs_ = [inp.type() for inp in graph_inputs]
        *rv_inputs_, lower_, upper_ = graph_inputs_

        # We use a shared RNG variable because Scan demands it, even though it
        # would not be necessary for the OpFromGraph inverse-cdf case.
        rng = aesara.shared(np.random.default_rng())
        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()

        # Try to use inverted cdf sampling
        try:
            # For left truncated discrete RVs, we need to include the whole lower bound.
            # This may result in draws below the truncation range, if any uniform == 0
            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
            cdf_lower_ = at.exp(logcdf(rv_, lower_value))
            cdf_upper_ = at.exp(logcdf(rv_, upper_))
            # It's okay to reuse the same rng here, because the rng in rv_ will not be
            # used by either the logcdf or icdf functions
            uniform_ = at.random.uniform(
                cdf_lower_,
                cdf_upper_,
                rng=rng,
                size=rv_inputs_[0],
            )
            truncated_rv_ = icdf(rv_, uniform_)
            return TruncatedRV(
                base_rv_op=dist.owner.op,
                inputs=graph_inputs_,
                outputs=[uniform_.owner.outputs[0], truncated_rv_],
                ndim_supp=0,
                max_n_steps=max_n_steps,
            )(*graph_inputs)
        except NotImplementedError:
            pass

        # Fallback to rejection sampling
        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
            truncated_rv = at.set_subtensor(
                truncated_rv[reject_draws],
                new_truncated_rv[reject_draws],
            )
            reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

            return (
                (truncated_rv, reject_draws),
                [(rng, next_rng)],
                until(~at.any(reject_draws)),
            )

        (truncated_rv_, reject_draws_), updates = scan(
            loop_fn,
            outputs_info=[
                at.zeros_like(rv_),
                at.ones_like(rv_, dtype=bool),
            ],
            non_sequences=[lower_, upper_, rng, *rv_inputs_],
            n_steps=max_n_steps,
            strict=True,
        )

        truncated_rv_ = truncated_rv_[-1]
        convergence_ = ~at.any(reject_draws_[-1])
        truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
            truncated_rv_, convergence_
        )

        return TruncatedRV(
            base_rv_op=dist.owner.op,
            inputs=graph_inputs_,
            outputs=[tuple(updates.values())[0], truncated_rv_],
            ndim_supp=0,
            max_n_steps=max_n_steps,
        )(*graph_inputs)
```
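A hedged usage sketch of `rv_op`'s dispatch order (not part of the commit): a specialized `_truncated` Op is tried first, then inverse-CDF sampling, then rejection sampling via `scan`. `pm.draw` is assumed available in this PyMC version.

```python
import pymc as pm

# Base distribution created via the `.dist()` API, as required by `Truncated.dist`
beta = pm.Beta.dist(alpha=2, beta=2)

# Beta has no `_truncated` specialization, so a TruncatedRV graph is built.
# Whether the inverse-CDF or the rejection-sampling branch is taken depends on
# whether aeppl provides an icdf for Beta.
trunc_beta = pm.Truncated.dist(beta, lower=0.1, upper=0.5)

draws = pm.draw(trunc_beta, draws=1_000)
assert draws.min() >= 0.1 and draws.max() <= 0.5
```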
```python
# pymc/distributions/truncated.py (continued)

@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op, dist, new_size, expand):
    *rv_inputs, lower, upper, rng = dist.owner.inputs
    # Recreate the original untruncated RV
    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
    if expand:
        new_size = to_tuple(new_size) + tuple(dist.shape)

    return Truncated.rv_op(
        untruncated_rv,
        lower=lower,
        upper=upper,
        size=new_size,
        max_n_steps=op.max_n_steps,
    )


@_moment.register(TruncatedRV)
def truncated_moment(op, rv, *inputs):
    *rv_inputs, lower, upper, rng = inputs

    # Recreate the untruncated RV and its moment
    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
    untruncated_moment = moment(untruncated_rv)

    fallback_moment = at.switch(
        at.and_(at.bitwise_not(at.isinf(lower)), at.bitwise_not(at.isinf(upper))),
        (upper - lower) / 2,  # lower and upper are finite
        at.switch(
            at.isinf(upper),
            lower + 1,  # only lower is finite
            upper - 1,  # only upper is finite
        ),
    )

    return at.switch(
        at.and_(at.ge(untruncated_moment, lower), at.le(untruncated_moment, upper)),
        untruncated_moment,  # untruncated moment is between lower and upper
        fallback_moment,
    )
```
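A small sketch of the fallback logic above (hypothetical values; `moment` is imported exactly as in this file's import block): an `Exponential(1)` has moment 1, which lies below a lower bound of 3, so the truncated moment falls back to `lower + 1`.

```python
import pymc as pm
from pymc.distributions.distribution import moment

x = pm.Truncated.dist(pm.Exponential.dist(1.0), lower=3.0)
print(moment(x).eval())  # expected: 4.0, i.e. lower + 1, since only lower is finite
```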
```python
# pymc/distributions/truncated.py (continued)

@_default_transform.register(TruncatedRV)
def truncated_default_transform(op, rv):
    # Don't transform discrete truncated distributions
    if op.base_rv_op.dtype.startswith("int"):
        return None
    # Lower and upper are the arguments at positions -3 and -2
    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
    (value,) = values

    *rv_inputs, lower, upper, rng = inputs
    rv_inputs = [rng, *rv_inputs]

    base_rv_op = op.base_rv_op
    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
    # For left truncated discrete RVs, we don't want to include the lower bound in the
    # normalization term
    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)

    if base_rv_op.name:
        logp.name = f"{base_rv_op}_logprob"
        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

    lognorm = 0
    if is_lower_bounded and is_upper_bounded:
        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    elif is_lower_bounded:
        lognorm = at.log1mexp(lower_logcdf)
    elif is_upper_bounded:
        lognorm = upper_logcdf

    logp = logp - lognorm

    if is_lower_bounded:
        logp = at.switch(value < lower, -np.inf, logp)

    if is_upper_bounded:
        logp = at.switch(value <= upper, logp, -np.inf)

    if is_lower_bounded and is_upper_bounded:
        logp = check_parameters(
            logp,
            at.le(lower, upper),
            msg="lower_bound <= upper_bound",
        )

    return logp
```
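A sanity sketch of the normalization term (uses SciPy for reference; not part of the commit): for a doubly bounded continuous variable, subtracting `logdiffexp(upper_logcdf, lower_logcdf)` from the base logp reproduces the closed-form truncated density.

```python
import numpy as np
from scipy import stats

x, lower, upper = 0.5, -1.0, 2.0

# logp - log(CDF(upper) - CDF(lower)), as in truncated_logprob above
manual = stats.norm.logpdf(x) - np.log(stats.norm.cdf(upper) - stats.norm.cdf(lower))

# scipy's truncnorm expresses the same density directly (loc=0, scale=1)
reference = stats.truncnorm.logpdf(x, a=lower, b=upper)
assert np.isclose(manual, reference)
```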
```python
# pymc/distributions/truncated.py (continued)

@_truncated.register(NormalRV)
def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma):
    return TruncatedNormal.dist(
        mu=mu,
        sigma=sigma,
        lower=lower,
        upper=upper,
        rng=None,  # Do not reuse rng to avoid weird dependencies
        size=size,
        dtype=dtype,
    )
```
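Because of this registration, truncating a Normal bypasses the generic machinery entirely; a minimal sketch (the printed Op is an expectation, not verified output):

```python
import pymc as pm

# Dispatches to _truncated_normal above, so the result is a TruncatedNormal
# RandomVariable rather than a TruncatedRV symbolic graph
x = pm.Truncated.dist(pm.Normal.dist(mu=0.0, sigma=1.0), lower=0.0)
print(x.owner.op)  # expected: the TruncatedNormal RV Op
```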

pymc/exceptions.py (+4)

```diff
@@ -74,3 +74,7 @@ def __init__(self, message, actual=None, expected=None):
             super().__init__(f"{message} (expected {expected})")
         else:
             super().__init__(message)
+
+
+class TruncationError(Exception):
+    """Exception for errors generated from truncated graphs"""
```
