Skip to content

Commit 657eb2a

Browse files
committed
Allow for scalar or size 1 mu in MvNormal and MvStudentT
1 parent e0ea364 commit 657eb2a

File tree

4 files changed

+55
-3
lines changed

4 files changed

+55
-3
lines changed

pymc/distributions/multivariate.py

+4
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ class MvNormal(Continuous):
233233
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
234234
mu = at.as_tensor_variable(mu)
235235
cov = quaddist_matrix(cov, chol, tau, lower)
236+
# Aesara is stricter about the shape of mu, than PyMC used to be
237+
mu = at.broadcast_arrays(mu, cov[..., -1])[0]
236238
return super().dist([mu, cov], **kwargs)
237239

238240
def get_moment(rv, size, mu, cov):
@@ -362,6 +364,8 @@ def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True
362364
nu = at.as_tensor_variable(floatX(nu))
363365
mu = at.as_tensor_variable(floatX(mu))
364366
cov = quaddist_matrix(cov, chol, tau, lower)
367+
# Aesara is stricter about the shape of mu, than PyMC used to be
368+
mu = at.broadcast_arrays(mu, cov[..., -1])[0]
365369
assert_negative_support(nu, "nu", "MvStudentT")
366370
return super().dist([nu, mu, cov], **kwargs)
367371

pymc/tests/test_distributions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2945,7 +2945,7 @@ def setup_class(self):
29452945
r"sigma ~ N**+(0, 1)",
29462946
r"mu ~ Deterministic(f(beta, alpha))",
29472947
r"beta ~ N(0, 10)",
2948-
r"Z ~ N(<constant>, f())",
2948+
r"Z ~ N(f(), f())",
29492949
r"nb_with_p_n ~ NB(10, nbp)",
29502950
r"Y_obs ~ N(mu, sigma)",
29512951
r"pot ~ Potential(f(beta, alpha))",
@@ -2965,7 +2965,7 @@ def setup_class(self):
29652965
r"$\text{sigma} \sim \operatorname{N^{+}}(0,~1)$",
29662966
r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$",
29672967
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
2968-
r"$\text{Z} \sim \operatorname{N}(\text{<constant>},~f())$",
2968+
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
29692969
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
29702970
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
29712971
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",

pymc/tests/test_distributions_moments.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def test_mv_normal_moment(mu, cov, size, expected):
950950
with Model() as model:
951951
x = MvNormal("x", mu=mu, cov=cov, size=size)
952952

953-
# MvNormal logp is only impemented for up to 2D variables
953+
# MvNormal logp is only implemented for up to 2D variables
954954
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
955955

956956

pymc/tests/test_distributions_random.py

+48
Original file line numberDiff line numberDiff line change
@@ -1050,8 +1050,32 @@ class TestMvNormalCov(BaseTestDistribution):
10501050
"check_pymc_params_match_rv_op",
10511051
"check_pymc_draws_match_reference",
10521052
"check_rv_size",
1053+
"check_mu_broadcast_helper",
10531054
]
10541055

1056+
def check_mu_broadcast_helper(self):
1057+
"""Test that mu is broadcasted to the shape of cov"""
1058+
x = pm.MvNormal.dist(mu=1, cov=np.eye(3))
1059+
mu = x.owner.inputs[3]
1060+
assert mu.eval().shape == (3,)
1061+
1062+
x = pm.MvNormal.dist(mu=np.ones(1), cov=np.eye(3))
1063+
mu = x.owner.inputs[3]
1064+
assert mu.eval().shape == (3,)
1065+
1066+
x = pm.MvNormal.dist(mu=np.ones((1, 1)), cov=np.eye(3))
1067+
mu = x.owner.inputs[3]
1068+
assert mu.eval().shape == (1, 3)
1069+
1070+
x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.eye(3))
1071+
mu = x.owner.inputs[3]
1072+
assert mu.eval().shape == (10, 3)
1073+
1074+
# Cov is artificually limited to being 2D
1075+
# x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1076+
# mu = x.owner.inputs[3]
1077+
# assert mu.eval().shape == (10, 2, 3)
1078+
10551079

10561080
class TestMvNormalChol(BaseTestDistribution):
10571081
pymc_dist = pm.MvNormal
@@ -1111,6 +1135,7 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
11111135
"check_pymc_draws_match_reference",
11121136
"check_rv_size",
11131137
"check_errors",
1138+
"check_mu_broadcast_helper",
11141139
]
11151140

11161141
def check_errors(self):
@@ -1124,6 +1149,29 @@ def check_errors(self):
11241149
cov=np.full((2, 2), np.ones(2)),
11251150
)
11261151

1152+
def check_mu_broadcast_helper(self):
1153+
"""Test that mu is broadcasted to the shape of cov"""
1154+
x = pm.MvStudentT.dist(nu=4, mu=1, cov=np.eye(3))
1155+
mu = x.owner.inputs[4]
1156+
assert mu.eval().shape == (3,)
1157+
1158+
x = pm.MvStudentT.dist(nu=4, mu=np.ones(1), cov=np.eye(3))
1159+
mu = x.owner.inputs[4]
1160+
assert mu.eval().shape == (3,)
1161+
1162+
x = pm.MvStudentT.dist(nu=4, mu=np.ones((1, 1)), cov=np.eye(3))
1163+
mu = x.owner.inputs[4]
1164+
assert mu.eval().shape == (1, 3)
1165+
1166+
x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov=np.eye(3))
1167+
mu = x.owner.inputs[4]
1168+
assert mu.eval().shape == (10, 3)
1169+
1170+
# Cov is artificually limited to being 2D
1171+
# x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1172+
# mu = x.owner.inputs[4]
1173+
# assert mu.eval().shape == (10, 2, 3)
1174+
11271175

11281176
class TestMvStudentTChol(BaseTestDistribution):
11291177
pymc_dist = pm.MvStudentT

0 commit comments

Comments
 (0)