Skip to content

Commit da2030e

Browse files
morganstromricardoV94
authored andcommitted
Adds tests and mode for dirichlet multinomial distribution
1 parent 24f9bd4 commit da2030e

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

pymc/distributions/multivariate.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,7 @@ def get_moment(rv, size, a):
453453
norm_constant = at.sum(a, axis=-1)[..., None]
454454
moment = a / norm_constant
455455
if not rv_size_is_none(size):
456-
if isinstance(size, int):
457-
size = (size,)
458-
moment = at.full((*size, *a.shape), moment)
456+
moment = at.full(at.concatenate([size, a.shape]), moment)
459457
return moment
460458

461459
def logp(value, a):
@@ -684,6 +682,19 @@ def dist(cls, n, a, *args, **kwargs):
684682

685683
return super().dist([n, a], **kwargs)
686684

685+
def get_moment(rv, size, n, a):
686+
p = a / at.sum(a, axis=-1)
687+
mode = at.round(n * p)
688+
diff = n - at.sum(mode, axis=-1, keepdims=True)
689+
inc_bool_arr = at.abs_(diff) > 0
690+
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
691+
# Reshape mode according to base shape (ignoring size)
692+
mode = at.reshape(mode, rv.shape[size.size :])
693+
if not rv_size_is_none(size):
694+
output_size = at.concatenate([size, mode.shape])
695+
mode = at.full(output_size, mode)
696+
return mode
697+
687698
def logp(value, n, a):
688699
"""
689700
Calculate log-probability of DirichletMultinomial distribution

pymc/tests/test_distributions_moments.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Constant,
2222
DensityDist,
2323
Dirichlet,
24+
DirichletMultinomial,
2425
DiscreteUniform,
2526
DiscreteWeibull,
2627
ExGaussian,
@@ -112,7 +113,6 @@ def test_all_distributions_have_moments():
112113

113114
# Distributions that have been refactored but don't yet have moments
114115
not_implemented |= {
115-
dist_module.multivariate.DirichletMultinomial,
116116
dist_module.multivariate.Wishart,
117117
}
118118

@@ -797,10 +797,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
797797
),
798798
(
799799
np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
800-
(
801-
11,
802-
5,
803-
),
800+
(11, 5),
804801
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
805802
),
806803
],
@@ -1461,3 +1458,32 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
14611458
sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), n))
14621459
LKJCholeskyCov("x", n=n, eta=eta, sd_dist=sd_dist, size=size, compute_corr=False)
14631460
assert_moment_is_expected(model, expected, check_finite_logp=size is None)
1461+
1462+
1463+
@pytest.mark.parametrize(
1464+
"a, n, size, expected",
1465+
[
1466+
(np.array([2, 2, 2, 2]), 1, None, np.array([1, 0, 0, 0])),
1467+
(np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])),
1468+
(np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])),
1469+
(
1470+
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
1471+
np.array([[1], [10]]), # Dim: 2 x 1
1472+
None,
1473+
np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4
1474+
),
1475+
(
1476+
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
1477+
np.array([[1], [10]]), # Dim: 2 x 1
1478+
(2, 1),
1479+
np.full(
1480+
(2, 1, 2, 1, 4),
1481+
np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4
1482+
),
1483+
),
1484+
],
1485+
)
1486+
def test_dirichlet_multinomial_moment(a, n, size, expected):
1487+
with Model() as model:
1488+
DirichletMultinomial("x", n=n, a=a, size=size)
1489+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)