|
21 | 21 | Constant,
|
22 | 22 | DensityDist,
|
23 | 23 | Dirichlet,
|
| 24 | + DirichletMultinomial, |
24 | 25 | DiscreteUniform,
|
25 | 26 | DiscreteWeibull,
|
26 | 27 | ExGaussian,
|
@@ -112,7 +113,6 @@ def test_all_distributions_have_moments():
|
112 | 113 |
|
113 | 114 | # Distributions that have been refactored but don't yet have moments
|
114 | 115 | not_implemented |= {
|
115 |
| - dist_module.multivariate.DirichletMultinomial, |
116 | 116 | dist_module.multivariate.Wishart,
|
117 | 117 | }
|
118 | 118 |
|
@@ -797,10 +797,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
|
797 | 797 | ),
|
798 | 798 | (
|
799 | 799 | np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
|
800 |
| - ( |
801 |
| - 11, |
802 |
| - 5, |
803 |
| - ), |
| 800 | + (11, 5), |
804 | 801 | np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
|
805 | 802 | ),
|
806 | 803 | ],
|
@@ -1461,3 +1458,32 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
|
1461 | 1458 | sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), n))
|
1462 | 1459 | LKJCholeskyCov("x", n=n, eta=eta, sd_dist=sd_dist, size=size, compute_corr=False)
|
1463 | 1460 | assert_moment_is_expected(model, expected, check_finite_logp=size is None)
|
| 1461 | + |
| 1462 | + |
| 1463 | +@pytest.mark.parametrize( |
| 1464 | + "a, n, size, expected", |
| 1465 | + [ |
| 1466 | + (np.array([2, 2, 2, 2]), 1, None, np.array([1, 0, 0, 0])), |
| 1467 | + (np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])), |
| 1468 | + (np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])), |
| 1469 | + ( |
| 1470 | + np.array([[26, 26, 26, 22]]), # Dim: 1 x 4 |
| 1471 | + np.array([[1], [10]]), # Dim: 2 x 1 |
| 1472 | + None, |
| 1473 | + np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4 |
| 1474 | + ), |
| 1475 | + ( |
| 1476 | + np.array([[26, 26, 26, 22]]), # Dim: 1 x 4 |
| 1477 | + np.array([[1], [10]]), # Dim: 2 x 1 |
| 1478 | + (2, 1), |
| 1479 | + np.full( |
| 1480 | + (2, 1, 2, 1, 4), |
| 1481 | + np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4 |
| 1482 | + ), |
| 1483 | + ), |
| 1484 | + ], |
| 1485 | +) |
| 1486 | +def test_dirichlet_multinomial_moment(a, n, size, expected): |
| 1487 | + with Model() as model: |
| 1488 | + DirichletMultinomial("x", n=n, a=a, size=size) |
| 1489 | + assert_moment_is_expected(model, expected) |
0 commit comments