Skip to content

Commit ef5653f

Browse files
committed
Fix KroneckerNormal ndim_supp
1 parent 0d86bdd commit ef5653f

File tree

2 files changed

+2
-6
lines changed

2 files changed

+2
-6
lines changed

pymc/distributions/multivariate.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -1829,14 +1829,11 @@ def logp(value, mu, rowchol, colchol):
18291829

18301830
class KroneckerNormalRV(RandomVariable):
18311831
name = "kroneckernormal"
1832-
ndim_supp = 2
1832+
ndim_supp = 1
18331833
ndims_params = [1, 0, 2]
18341834
dtype = "floatX"
18351835
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
18361836

1837-
def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
1838-
return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)
1839-
18401837
def rng_fn(self, rng, mu, sigma, *covs, size=None):
18411838
size = size if size else covs[-1]
18421839
covs = covs[:-1] if covs[-1] == size else covs
@@ -1965,7 +1962,6 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
19651962

19661963
mu = at.as_tensor_variable(mu)
19671964

1968-
# mean = median = mode = mu
19691965
return super().dist([mu, sigma, *covs], **kwargs)
19701966

19711967
def get_moment(rv, size, mu, covs, chols, evds):

pymc/tests/test_distributions_moments.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1415,7 +1415,7 @@ def normal_sim(rng, mu, sigma, size):
14151415
),
14161416
],
14171417
)
1418-
def test_kronecker_normal_moments(mu, covs, size, expected):
1418+
def test_kronecker_normal_moment(mu, covs, size, expected):
14191419
with Model() as model:
14201420
KroneckerNormal("x", mu=mu, covs=covs, size=size)
14211421
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)