Skip to content

Commit 9e4fcdb

Browse files
committed
Create helper wrapper around Aeppl IntervalTransform
1 parent 92a5866 commit 9e4fcdb

File tree

7 files changed

+119
-36
lines changed

7 files changed

+119
-36
lines changed

Diff for: docs/source/api/distributions/transforms.rst

+11-9
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,30 @@ Transform instances are the entities that should be used in the
1515

1616
simplex
1717
logodds
18-
interval
1918
log_exp_m1
2019
ordered
2120
log
2221
sum_to_1
2322
circular
2423

25-
Transform Composition Classes
26-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
27-
28-
.. autosummary::
29-
:toctree: generated
30-
31-
Chain
32-
CholeskyCovPacked
3324

3425
Specific Transform Classes
3526
~~~~~~~~~~~~~~~~~~~~~~~~~~
3627

3728
.. autosummary::
3829
:toctree: generated
3930

31+
CholeskyCovPacked
32+
Interval
4033
LogExpM1
4134
Ordered
4235
SumTo1
36+
37+
38+
Transform Composition Classes
39+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
40+
41+
.. autosummary::
42+
:toctree: generated
43+
44+
Chain

Diff for: pymc/distributions/continuous.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def transform_params(*args):
195195

196196
return lower, upper
197197

198-
return transforms.interval(transform_params)
198+
return transforms.Interval(bounds_fn=transform_params)
199199

200200

201201
def assert_negative_support(var, label, distname, value=-1e-6):
@@ -3796,7 +3796,7 @@ def transform_params(*params):
37963796
_, _, _, x_points, _, _ = params
37973797
return floatX(x_points[0]), floatX(x_points[-1])
37983798

3799-
kwargs["transform"] = transforms.interval(transform_params)
3799+
kwargs["transform"] = transforms.Interval(bounds_fn=transform_params)
38003800
return super().__new__(cls, *args, **kwargs)
38013801

38023802
@classmethod

Diff for: pymc/distributions/multivariate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
rv_size_is_none,
6363
to_tuple,
6464
)
65-
from pymc.distributions.transforms import interval
65+
from pymc.distributions.transforms import Interval
6666
from pymc.math import kron_diag, kron_dot
6767
from pymc.util import UNSET, check_dist_not_registered
6868

@@ -1554,7 +1554,7 @@ class LKJCorr(BoundedContinuous):
15541554
def __new__(cls, *args, **kwargs):
15551555
transform = kwargs.get("transform", UNSET)
15561556
if transform is UNSET:
1557-
kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
1557+
kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0))
15581558
return super().__new__(cls, *args, **kwargs)
15591559

15601560
@classmethod

Diff for: pymc/distributions/transforms.py

+83-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import aesara.tensor as at
16+
import numpy as np
1617

1718
from aeppl.transforms import (
1819
CircularTransform,
@@ -27,7 +28,7 @@
2728
"RVTransform",
2829
"simplex",
2930
"logodds",
30-
"interval",
31+
"Interval",
3132
"log_exp_m1",
3233
"ordered",
3334
"log",
@@ -174,10 +175,87 @@ def log_jac_det(self, value, *inputs):
174175
Instantiation of :class:`aeppl.transforms.LogOddsTransform`
175176
for use in the ``transform`` argument of a random variable."""
176177

177-
interval = IntervalTransform
178-
interval.__doc__ = """
179-
Instantiation of :class:`aeppl.transforms.IntervalTransform`
180-
for use in the ``transform`` argument of a random variable."""
178+
179+
class Interval(IntervalTransform):
180+
"""Wrapper around :class:`aeppl.transforms.IntervalTransform` for use in the
181+
``transform`` argument of a random variable.
182+
183+
Parameters
184+
----------
185+
lower : int, float, or None
186+
Lower bound of the interval transform. Must be a constant value. If ``None``, the
187+
interval is not bounded below.
188+
upper : int, float or None
189+
Upper bound of the interval transfrom. Must be a finite value. If ``None``, the
190+
interval is not bounded above.
191+
bounds_fn : callable
192+
Alternative to lower and upper. Must return a tuple of lower and upper bounds
193+
as a symbolic function of the respective distribution inputs. If lower or
194+
upper is ``None``, the interval is unbounded on that edge.
195+
196+
.. warning:: Expressions returned by `bounds_fn` should depend only on the
197+
distribution inputs or other constants. Expressions that depend on other
198+
symbolic variables, including nonlocal variables defined in the model
199+
context will likely break sampling.
200+
201+
202+
Examples
203+
--------
204+
.. code-block:: python
205+
206+
# Create an interval transform between -1 and +1
207+
with pm.Model():
208+
interval = pm.distributions.transforms.Interval(lower=-1, upper=1)
209+
x = pm.Normal("x", transform=interval)
210+
211+
.. code-block:: python
212+
213+
# Create an interval transform between -1 and +1 using a callable
214+
def get_bounds(rng, size, dtype, loc, scale):
215+
return 0, None
216+
217+
with pm.Model():
218+
interval = pm.distributions.transforms.Interval(bouns_fn=get_bounds)
219+
x = pm.Normal("x", transform=interval)
220+
221+
.. code-block:: python
222+
223+
# Create a lower bounded interval transform based on a distribution parameter
224+
def get_bounds(rng, size, dtype, loc, scale):
225+
return loc, None
226+
227+
interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
228+
229+
with pm.Model():
230+
loc = pm.Normal("loc")
231+
x = pm.Normal("x", mu=loc, sigma=2, transform=interval)
232+
"""
233+
234+
def __init__(self, lower=None, upper=None, *, bounds_fn=None):
235+
if bounds_fn is None:
236+
try:
237+
bounds = tuple(
238+
None if bound is None else at.constant(bound, ndim=0).data
239+
for bound in (lower, upper)
240+
)
241+
except (ValueError, TypeError):
242+
raise ValueError(
243+
"Interval bounds must be constant values. If you need expressions that "
244+
"depend on symbolic variables use `args_fn`"
245+
)
246+
247+
lower, upper = (
248+
None if (bound is None or np.isinf(bound)) else bound for bound in bounds
249+
)
250+
251+
if lower is None and upper is None:
252+
raise ValueError("Lower and upper interval bounds cannot both be None")
253+
254+
def bounds_fn(*rv_inputs):
255+
return lower, upper
256+
257+
super().__init__(args_fn=bounds_fn)
258+
181259

182260
log_exp_m1 = LogExpM1()
183261
log_exp_m1.__doc__ = """

Diff for: pymc/tests/test_distributions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2710,7 +2710,7 @@ def test_arguments_checks(self):
27102710
with pm.Model() as m:
27112711
x = pm.Poisson.dist(0.5)
27122712
with pytest.raises(ValueError, match=msg):
2713-
pm.Bound("bound", x, transform=pm.transforms.interval)
2713+
pm.Bound("bound", x, transform=pm.distributions.transforms.log)
27142714

27152715
msg = "Given dims do not exist in model coordinates."
27162716
with pm.Model() as m:

Diff for: pymc/tests/test_sampling.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,13 @@ def test_deterministic_of_unobserved(self):
327327

328328
np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100)
329329

330-
def test_transform_with_rv_depenency(self):
330+
def test_transform_with_rv_dependency(self):
331331
# Test that untransformed variables that depend on upstream variables are properly handled
332332
with pm.Model() as m:
333333
x = pm.HalfNormal("x", observed=1)
334-
transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1]))
334+
transform = pm.distributions.transforms.Interval(
335+
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
336+
)
335337
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
336338
trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336)
337339

Diff for: pymc/tests/test_transforms.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,7 @@ def test_logodds():
177177

178178

179179
def test_lowerbound():
180-
def transform_params(*inputs):
181-
return 0.0, None
182-
183-
trans = tr.interval(transform_params)
180+
trans = tr.Interval(0.0, None)
184181
check_transform(trans, Rplusbig)
185182

186183
check_jacobian_det(trans, Rplusbig, elemwise=True)
@@ -191,10 +188,7 @@ def transform_params(*inputs):
191188

192189

193190
def test_upperbound():
194-
def transform_params(*inputs):
195-
return None, 0.0
196-
197-
trans = tr.interval(transform_params)
191+
trans = tr.Interval(None, 0.0)
198192
check_transform(trans, Rminusbig)
199193

200194
check_jacobian_det(trans, Rminusbig, elemwise=True)
@@ -208,10 +202,7 @@ def test_interval():
208202
for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]:
209203
domain = Unit * np.float64(b - a) + np.float64(a)
210204

211-
def transform_params(z=a, y=b):
212-
return z, y
213-
214-
trans = tr.interval(transform_params)
205+
trans = tr.Interval(a, b)
215206
check_transform(trans, domain)
216207

217208
check_jacobian_det(trans, domain, elemwise=True)
@@ -375,7 +366,7 @@ def transform_params(*inputs):
375366
upper = at.as_tensor_variable(upper) if upper is not None else None
376367
return lower, upper
377368

378-
interval = tr.interval(transform_params)
369+
interval = tr.Interval(bounds_fn=transform_params)
379370
model = self.build_model(
380371
pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval
381372
)
@@ -396,7 +387,7 @@ def transform_params(*inputs):
396387
upper = at.as_tensor_variable(upper) if upper is not None else None
397388
return lower, upper
398389

399-
interval = tr.interval(transform_params)
390+
interval = tr.Interval(bounds_fn=transform_params)
400391
model = self.build_model(
401392
pm.Triangular, {"lower": lower, "c": c, "upper": upper}, size=size, transform=interval
402393
)
@@ -491,7 +482,7 @@ def transform_params(*inputs):
491482
upper = at.as_tensor_variable(upper) if upper is not None else None
492483
return lower, upper
493484

494-
interval = tr.interval(transform_params)
485+
interval = tr.Interval(bounds_fn=transform_params)
495486

496487
initval = np.sort(np.abs(np.random.rand(*size)))
497488
model = self.build_model(
@@ -556,3 +547,13 @@ def test_triangular_transform():
556547
transform = x.tag.value_var.tag.transform
557548
assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0)
558549
assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2)
550+
551+
552+
def test_interval_transform_raises():
553+
with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"):
554+
tr.Interval(None, None)
555+
556+
with pytest.raises(ValueError, match="Interval bounds must be constant values"):
557+
tr.Interval(at.constant(5) + 1, None)
558+
559+
assert tr.Interval(at.constant(5), None)

0 commit comments

Comments
 (0)