Skip to content

Commit 0b07970

Browse files
authored
fix #4216: str and latex representations for Bound variables (#4217)
* handling of parameters equal to None * improved type juggling of bounds * updating str repr for bounded variables * adding bounded variable to str/_repr_latex test * updating test * black formatting * removing unused import * change string to r-string
1 parent e51b9d3 commit 0b07970

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

pymc3/distributions/bound.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy as np
1818
import theano.tensor as tt
19-
import theano
2019

2120
from pymc3.distributions.distribution import (
2221
Distribution,
@@ -28,6 +27,8 @@
2827
from pymc3.distributions import transforms
2928
from pymc3.distributions.dist_math import bound
3029

30+
from pymc3.theanof import floatX
31+
3132
__all__ = ["Bound"]
3233

3334

@@ -148,6 +149,25 @@ def random(self, point=None, size=None):
148149
not_broadcast_kwargs={"point": point},
149150
)
150151

152+
def _distr_parameters_for_repr(self):
153+
return ["lower", "upper"]
154+
155+
def _distr_name_for_repr(self):
156+
return "Bound"
157+
158+
def _str_repr(self, **kwargs):
159+
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
160+
if "formatting" in kwargs and kwargs["formatting"] == "latex":
161+
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
162+
else:
163+
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
164+
self_repr = super()._str_repr(**kwargs)
165+
166+
if "formatting" in kwargs and kwargs["formatting"] == "latex":
167+
return self_repr + " -- " + distr_repr
168+
else:
169+
return self_repr + "-" + distr_repr
170+
151171

152172
class _DiscreteBounded(_Bounded, Discrete):
153173
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):
@@ -187,12 +207,10 @@ class _ContinuousBounded(_Bounded, Continuous):
187207
"""
188208

189209
def __init__(self, distribution, lower, upper, transform="infer", *args, **kwargs):
190-
dtype = kwargs.get("dtype", theano.config.floatX)
191-
192210
if lower is not None:
193-
lower = tt.as_tensor_variable(lower).astype(dtype)
211+
lower = tt.as_tensor_variable(floatX(lower))
194212
if upper is not None:
195-
upper = tt.as_tensor_variable(upper).astype(dtype)
213+
upper = tt.as_tensor_variable(floatX(upper))
196214

197215
if transform == "infer":
198216
if lower is None and upper is None:

pymc3/tests/test_distributions.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1779,16 +1779,20 @@ def setup_class(self):
17791779
# Expected value of outcome
17801780
mu = Deterministic("mu", floatX(alpha + tt.dot(X, b)))
17811781

1782+
# add a bounded variable as well
1783+
bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)
1784+
17821785
# Likelihood (sampling distribution) of observations
17831786
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
1784-
self.distributions = [alpha, sigma, mu, b, Z, Y_obs]
1787+
self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var]
17851788
self.expected_latex = (
17861789
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
17871790
r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$",
17881791
r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$",
17891792
r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
17901793
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
17911794
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
1795+
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
17921796
)
17931797
self.expected_str = (
17941798
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
@@ -1797,6 +1801,7 @@ def setup_class(self):
17971801
r"beta ~ Normal(mu=0.0, sigma=10.0)",
17981802
r"Z ~ MvNormal(mu=array, chol_cov=array)",
17991803
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
1804+
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
18001805
)
18011806

18021807
def test__repr_latex_(self):

pymc3/util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def get_default_varnames(var_iterator, include_transformed):
128128

129129
def get_repr_for_variable(variable, formatting="plain"):
130130
"""Build a human-readable string representation for a variable."""
131-
name = variable.name
132-
if name is None:
131+
name = variable.name if variable is not None else None
132+
if name is None and variable is not None:
133133
if hasattr(variable, "get_parents"):
134134
try:
135135
names = [

0 commit comments

Comments
 (0)