@@ -1050,8 +1050,32 @@ class TestMvNormalCov(BaseTestDistribution):
1050
1050
"check_pymc_params_match_rv_op" ,
1051
1051
"check_pymc_draws_match_reference" ,
1052
1052
"check_rv_size" ,
1053
+ "check_mu_broadcast_helper" ,
1053
1054
]
1054
1055
1056
+ def check_mu_broadcast_helper (self ):
1057
+ """Test that mu is broadcasted to the shape of cov"""
1058
+ x = pm .MvNormal .dist (mu = 1 , cov = np .eye (3 ))
1059
+ mu = x .owner .inputs [3 ]
1060
+ assert mu .eval ().shape == (3 ,)
1061
+
1062
+ x = pm .MvNormal .dist (mu = np .ones (1 ), cov = np .eye (3 ))
1063
+ mu = x .owner .inputs [3 ]
1064
+ assert mu .eval ().shape == (3 ,)
1065
+
1066
+ x = pm .MvNormal .dist (mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1067
+ mu = x .owner .inputs [3 ]
1068
+ assert mu .eval ().shape == (1 , 3 )
1069
+
1070
+ x = pm .MvNormal .dist (mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1071
+ mu = x .owner .inputs [3 ]
1072
+ assert mu .eval ().shape == (10 , 3 )
1073
+
1074
+ # Cov is artificually limited to being 2D
1075
+ # x = pm.MvNormal.dist(mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1076
+ # mu = x.owner.inputs[3]
1077
+ # assert mu.eval().shape == (10, 2, 3)
1078
+
1055
1079
1056
1080
class TestMvNormalChol (BaseTestDistribution ):
1057
1081
pymc_dist = pm .MvNormal
@@ -1111,6 +1135,7 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
1111
1135
"check_pymc_draws_match_reference" ,
1112
1136
"check_rv_size" ,
1113
1137
"check_errors" ,
1138
+ "check_mu_broadcast_helper" ,
1114
1139
]
1115
1140
1116
1141
def check_errors (self ):
@@ -1124,6 +1149,29 @@ def check_errors(self):
1124
1149
cov = np .full ((2 , 2 ), np .ones (2 )),
1125
1150
)
1126
1151
1152
+ def check_mu_broadcast_helper (self ):
1153
+ """Test that mu is broadcasted to the shape of cov"""
1154
+ x = pm .MvStudentT .dist (nu = 4 , mu = 1 , cov = np .eye (3 ))
1155
+ mu = x .owner .inputs [4 ]
1156
+ assert mu .eval ().shape == (3 ,)
1157
+
1158
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), cov = np .eye (3 ))
1159
+ mu = x .owner .inputs [4 ]
1160
+ assert mu .eval ().shape == (3 ,)
1161
+
1162
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1163
+ mu = x .owner .inputs [4 ]
1164
+ assert mu .eval ().shape == (1 , 3 )
1165
+
1166
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1167
+ mu = x .owner .inputs [4 ]
1168
+ assert mu .eval ().shape == (10 , 3 )
1169
+
1170
+ # Cov is artificually limited to being 2D
1171
+ # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov=np.full((2, 3, 3), np.eye(3)))
1172
+ # mu = x.owner.inputs[4]
1173
+ # assert mu.eval().shape == (10, 2, 3)
1174
+
1127
1175
1128
1176
class TestMvStudentTChol (BaseTestDistribution ):
1129
1177
pymc_dist = pm .MvStudentT
0 commit comments