Skip to content

Check that concentration parameters of Dirichlet distribution are all > 0 #3853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 3, 2020
Merged
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- `pm.sample` now takes 1000 draws and 1000 tuning samples by default, instead of 500 previously (see [#3855](https://github.com/pymc-devs/pymc3/pull/3855)).
- Dropped the outdated 'nuts' initialization method for `pm.sample` (see [#3863](https://github.com/pymc-devs/pymc3/pull/3863)).
- Moved argument division out of `NegativeBinomial` `random` method. Fixes [#3864](https://github.com/pymc-devs/pymc3/issues/3864) in the style of [#3509](https://github.com/pymc-devs/pymc3/pull/3509).
- The Dirichlet distribution now raises a ValueError when it's initialized with <= 0 values (see [#3853](https://github.com/pymc-devs/pymc3/pull/3853)).

## PyMC3 3.8 (November 29 2019)

Expand Down
10 changes: 10 additions & 0 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ class Dirichlet(Continuous):

def __init__(self, a, transform=transforms.stick_breaking,
*args, **kwargs):

if not isinstance(a, pm.model.TensorVariable):
if not isinstance(a, list) and not isinstance(a, np.ndarray):
raise TypeError(
'The vector of concentration parameters (a) must be a python list '
'or numpy array.')
a = np.array(a)
if (a <= 0).any():
raise ValueError("All concentration parameters (a) must be > 0.")

shape = np.atleast_1d(a.shape)[-1]

kwargs.setdefault("shape", shape)
Expand Down
38 changes: 32 additions & 6 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,17 +944,43 @@ def test_lkj(self, x, eta, n, lp):

@pytest.mark.parametrize('n', [2, 3])
def test_dirichlet(self, n):
self.pymc3_matches_scipy(Dirichlet, Simplex(
n), {'a': Vector(Rplus, n)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
Simplex(n),
{'a': Vector(Rplus, n)},
dirichlet_logpdf
)

@pytest.mark.parametrize('n', [3, 4])
def test_dirichlet_init_fail(self, n):
with Model():
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.zeros(n), shape=n)
with pytest.raises(
ValueError,
match=r"All concentration parameters \(a\) must be > 0."
):
_ = Dirichlet('x', a=np.array([-1.] * n), shape=n)

def test_dirichlet_2D(self):
self.pymc3_matches_scipy(Dirichlet, MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)}, dirichlet_logpdf)
self.pymc3_matches_scipy(
Dirichlet,
MultiSimplex(2, 2),
{'a': Vector(Vector(Rplus, 2), 2)},
dirichlet_logpdf
)

@pytest.mark.parametrize('n', [2, 3])
def test_multinomial(self, n):
self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat},
multinomial_logpdf)
self.pymc3_matches_scipy(
Multinomial,
Vector(Nat, n),
{'p': Simplex(n), 'n': Nat},
multinomial_logpdf
)

@pytest.mark.parametrize('p,n', [
[[.25, .25, .25, .25], 1],
Expand Down