Skip to content

Commit 91993d8

Browse files
authored
Fix MatrixNormal.random (#4368)
* Fix MatrixNormal.random * Provided better context for tests * Worked on suggestions * Added a test to signify need of transpose of cholesky matrix. * Used np.swapaxes to take transpose * Given a mention in release notes
1 parent 3cfee77 commit 91993d8

File tree

3 files changed

+62
-24
lines changed

3 files changed

+62
-24
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
2727
- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).
2828
- Fixed mathematical formulation in `MvStudentT` random method. (see [#4359](https://github.com/pymc-devs/pymc3/pull/4359))
2929
- Fix issue in `logp` method of `HyperGeometric`. It now returns `-inf` for invalid parameters (see [4367](https://github.com/pymc-devs/pymc3/pull/4367))
30+
- Fixed `MatrixNormal` random method to work with parameters as random variables. (see [#4368](https://github.com/pymc-devs/pymc3/pull/4368))
3031

3132
## PyMC3 3.10.0 (7 December 2020)
3233

Diff for: pymc3/distributions/multivariate.py

+14-22
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,7 @@ class MatrixNormal(Continuous):
14451445
14461446
.. math::
14471447
f(x \mid \mu, U, V) =
1448-
\frac{1}{(2\pi |U|^n |V|^m)^{1/2}}
1448+
\frac{1}{(2\pi^{m n} |U|^n |V|^m)^{1/2}}
14491449
\exp\left\{
14501450
-\frac{1}{2} \mathrm{Tr}[ V^{-1} (x-\mu)^{\prime} U^{-1} (x-\mu)]
14511451
\right\}
@@ -1637,27 +1637,19 @@ def random(self, point=None, size=None):
16371637
mu, colchol, rowchol = draw_values(
16381638
[self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
16391639
)
1640-
if size is None:
1641-
size = ()
1642-
if size in (None, ()):
1643-
standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
1644-
samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T))
1645-
else:
1646-
samples = []
1647-
size = tuple(np.atleast_1d(size))
1648-
if mu.shape == tuple(self.shape):
1649-
for _ in range(np.prod(size)):
1650-
standard_normal = np.random.standard_normal((self.shape[0], colchol.shape[-1]))
1651-
samples.append(mu + np.matmul(rowchol, np.matmul(standard_normal, colchol.T)))
1652-
else:
1653-
for j in range(np.prod(size)):
1654-
standard_normal = np.random.standard_normal(
1655-
(self.shape[0], colchol[j].shape[-1])
1656-
)
1657-
samples.append(
1658-
mu[j] + np.matmul(rowchol[j], np.matmul(standard_normal, colchol[j].T))
1659-
)
1660-
samples = np.array(samples).reshape(size + tuple(self.shape))
1640+
size = to_tuple(size)
1641+
dist_shape = to_tuple(self.shape)
1642+
output_shape = size + dist_shape
1643+
1644+
# Broadcasting all parameters
1645+
(mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
1646+
rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
1647+
1648+
colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
1649+
colchol = np.swapaxes(colchol, -1, -2) # Take transpose
1650+
1651+
standard_normal = np.random.standard_normal(output_shape)
1652+
samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
16611653
return samples
16621654

16631655
def _trquaddist(self, value):

Diff for: pymc3/tests/test_distributions_random.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,12 @@ def ref_rand_chol(size, mu, rowchol, colchol):
849849
size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T)
850850
)
851851

852+
def ref_rand_chol_transpose(size, mu, rowchol, colchol):
853+
colchol = colchol.T
854+
return ref_rand(
855+
size, mu, rowcov=np.dot(rowchol, rowchol.T), colcov=np.dot(colchol, colchol.T)
856+
)
857+
852858
def ref_rand_uchol(size, mu, rowchol, colchol):
853859
return ref_rand(
854860
size, mu, rowcov=np.dot(rowchol.T, rowchol), colcov=np.dot(colchol.T, colchol)
@@ -858,7 +864,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
858864
pymc3_random(
859865
pm.MatrixNormal,
860866
{"mu": RealMatrix(n, n), "rowcov": PdMatrix(n), "colcov": PdMatrix(n)},
861-
size=n,
867+
size=100,
862868
valuedomain=RealMatrix(n, n),
863869
ref_rand=ref_rand,
864870
)
@@ -867,7 +873,7 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
867873
pymc3_random(
868874
pm.MatrixNormal,
869875
{"mu": RealMatrix(n, n), "rowchol": PdMatrixChol(n), "colchol": PdMatrixChol(n)},
870-
size=n,
876+
size=100,
871877
valuedomain=RealMatrix(n, n),
872878
ref_rand=ref_rand_chol,
873879
)
@@ -878,6 +884,22 @@ def ref_rand_uchol(size, mu, rowchol, colchol):
878884
# extra_args={'lower': False}
879885
# )
880886

887+
# 2 sample test fails because cov becomes different if chol is transposed beforehand.
888+
# This implicity means we need transpose of chol after drawing values in
889+
# MatrixNormal.random method to match stats.matrix_normal.rvs method
890+
with pytest.raises(AssertionError):
891+
pymc3_random(
892+
pm.MatrixNormal,
893+
{
894+
"mu": RealMatrix(n, n),
895+
"rowchol": PdMatrixChol(n),
896+
"colchol": PdMatrixChol(n),
897+
},
898+
size=100,
899+
valuedomain=RealMatrix(n, n),
900+
ref_rand=ref_rand_chol_transpose,
901+
)
902+
881903
def test_kronecker_normal(self):
882904
def ref_rand(size, mu, covs, sigma):
883905
cov = pm.math.kronecker(covs[0], covs[1]).eval()
@@ -1675,3 +1697,26 @@ def test_issue_3706(self):
16751697
prior_pred = pm.sample_prior_predictive(1)
16761698

16771699
assert prior_pred["X"].shape == (1, N, 2)
1700+
1701+
1702+
def test_matrix_normal_random_with_random_variables():
1703+
"""
1704+
This test checks for shape correctness when using MatrixNormal distribution
1705+
with parameters as random variables.
1706+
Originally reported - https://github.com/pymc-devs/pymc3/issues/3585
1707+
"""
1708+
K = 3
1709+
D = 15
1710+
mu_0 = np.zeros((D, K))
1711+
lambd = 1.0
1712+
with pm.Model() as model:
1713+
sd_dist = pm.HalfCauchy.dist(beta=2.5)
1714+
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist)
1715+
L = pm.expand_packed_triangular(D, packedL, lower=True)
1716+
Sigma = pm.Deterministic("Sigma", L.dot(L.T)) # D x D covariance
1717+
mu = pm.MatrixNormal(
1718+
"mu", mu=mu_0, rowcov=(1 / lambd) * Sigma, colcov=np.eye(K), shape=(D, K)
1719+
)
1720+
prior = pm.sample_prior_predictive(2)
1721+
1722+
assert prior["mu"].shape == (2, D, K)

0 commit comments

Comments
 (0)