Skip to content

Commit a3a63da

Browse files
authored
Fix str representations for KroneckerNormal and MatrixNormal (#4243)
* fallback __str__ to default Theano on error * fix str repr for KroneckerNormal and MatrixNormal * black formatting * update release notes
1 parent 68d5201 commit a3a63da

File tree

5 files changed

+38
-3
lines changed

5 files changed

+38
-3
lines changed

Diff for: RELEASE-NOTES.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
- Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
2424
- Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
2525
- Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).
26-
- Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), and [#4217](https://github.com/pymc-devs/pymc3/pull/4217))
26+
- Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), [#4217](https://github.com/pymc-devs/pymc3/pull/4217), and [#4243](https://github.com/pymc-devs/pymc3/pull/4243)).
2727

2828

2929

Diff for: pymc3/distributions/distribution.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
194194
)
195195

196196
def __str__(self, **kwargs):
197-
return self._str_repr(formatting="plain", **kwargs)
197+
try:
198+
return self._str_repr(formatting="plain", **kwargs)
199+
except:
200+
return super().__str__()
198201

199202
def _repr_latex_(self, **kwargs):
200203
"""Magic method name for IPython to use for LaTeX formatting."""

Diff for: pymc3/distributions/multivariate.py

+10
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,9 @@ def logp(self, x):
14491449
broadcast_conditions=False,
14501450
)
14511451

1452+
def _distr_parameters_for_repr(self):
1453+
return ["eta", "n"]
1454+
14521455

14531456
class MatrixNormal(Continuous):
14541457
R"""
@@ -1712,6 +1715,10 @@ def logp(self, value):
17121715
norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
17131716
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
17141717

1718+
def _distr_parameters_for_repr(self):
1719+
mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
1720+
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
1721+
17151722

17161723
class KroneckerNormal(Continuous):
17171724
R"""
@@ -1954,3 +1961,6 @@ def logp(self, value):
19541961
"""
19551962
quad, logdet = self._quaddist(value)
19561963
return -(quad + logdet + self.N * tt.log(2 * np.pi)) / 2.0
1964+
1965+
def _distr_parameters_for_repr(self):
1966+
return ["mu"]

Diff for: pymc3/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def _repr_latex_(self, **kwargs):
8080
return self._str_repr(formatting="latex", **kwargs)
8181

8282
def __str__(self, **kwargs):
83-
return self._str_repr(formatting="plain", **kwargs)
83+
try:
84+
return self._str_repr(formatting="plain", **kwargs)
85+
except:
86+
return super().__str__()
8487

8588
__latex__ = _repr_latex_
8689

Diff for: pymc3/tests/test_distributions.py

+19
Original file line numberDiff line numberDiff line change
@@ -1782,8 +1782,23 @@ def setup_class(self):
17821782
# add a bounded variable as well
17831783
bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)
17841784

1785+
# KroneckerNormal
1786+
n, m = 3, 4
1787+
covs = [np.eye(n), np.eye(m)]
1788+
kron_normal = KroneckerNormal("kron_normal", mu=np.zeros(n * m), covs=covs, shape=n * m)
1789+
1790+
# MatrixNormal
1791+
matrix_normal = MatrixNormal(
1792+
"mat_normal",
1793+
mu=np.random.normal(size=n),
1794+
rowcov=np.eye(n),
1795+
colchol=np.linalg.cholesky(np.eye(n)),
1796+
shape=(n, n),
1797+
)
1798+
17851799
# Likelihood (sampling distribution) of observations
17861800
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
1801+
17871802
self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var]
17881803
self.expected_latex = (
17891804
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
@@ -1793,6 +1808,8 @@ def setup_class(self):
17931808
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
17941809
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
17951810
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1811+
r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$",
1812+
r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$",
17961813
)
17971814
self.expected_str = (
17981815
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
@@ -1802,6 +1819,8 @@ def setup_class(self):
18021819
r"Z ~ MvNormal(mu=array, chol_cov=array)",
18031820
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
18041821
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
1822+
r"kron_normal ~ KroneckerNormal(mu=array)",
1823+
r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)",
18051824
)
18061825

18071826
def test__repr_latex_(self):

0 commit comments

Comments
 (0)