
Commit 2a3d9a3

Authored by ricardoV94, bsmith89, and ColCarroll
Dirichlet multinomial (continued) (#4373)
* Add implementation of DM distribution.
* Fix class name mistake.
* Add DM dist to exported multivariate distributions.
* Export DirichletMultinomial in pymc3.distributions, as suggested in #3639 (comment). Also see #3639 (comment), though this seems to be part of a broader discussion.
* Attempt at matching Multinomial initialization.
* Add some simple tests for DM.
* Correctly deal with 1d n and 2d alpha.
* Fix typo in DM random.
* Fix faulty tests for DM.
* Drop redundant initialization test for DM.
* Add test that DM is normalized for the n=1 case.
* Add DM test case based on BetaBinomial.
* Update pymc3/distributions/multivariate.py.
* Infer shape by default (copied code from the Dirichlet distribution); add default shape in `test_distributions_random.py`.
* Use size information in the random method; change random unittests.
* Restore accidental merge deletions.
* Add missing underscore.
* More merge cleaning.
* Bring DirichletMultinomial initialization into alignment with Multinomial.
* Align all DM tests with Multinomial.
* Align DirichletMultinomial random implementation with Multinomial.
* Match DM random method to Multinomial implementation.
* Change alpha -> a; remove _repr_latex_.
* Run pre-commit.
* Keep standard order of methods random and logp.
* Update docstrings for valid input types; progress on batch test.
* Add new test to ensure DM matches BetaBinomial.
* Change DM alpha -> a in docstrings.
* Test two additional parameterization shapes in `test_dirichlet_multinomial_random`.
* Revert debugging comments.
* Revert unrelated changes.
* Fix minor Black inconsistency.
* Drop no-longer-functional reshaping code.
* Assert shape of random samples is as expected.
* Explicitly test random sample shapes, including batch dimensions.
* Sort imports.
* Simplify `_random`. It should be okay not to explicitly change the input dtype as in the multinomial, because the input to `np.random.dirichlet` should be safe (float32-to-float64 overflow from 1.00 to 1.01... is fine; underflow from 0.01 to 0.0 would still be problematic, but we don't know if this is an issue yet). The output of `numpy.random.dirichlet` fed to `numpy.random.multinomial` should be safe, since it is already in float64 by then. We still need to convert back to the previous dtype, since numpy changes it by default. The `size_` argument was no longer being used.
* Reorder tests more logically.
* Refactor tests: merge mode tests, since shape must be given explicitly anyway; move `test_dirichlet_multinomial_random` to `test_distributions_random.py` and rename it to `test_dirichlet_multinomial_shapes`.
* Require shape argument; also allow more forgiveness if the user passes lists instead of arrays (WIP/suggestion only).
* Remove unused import `to_tuple`.
* Simplify logic to handle a list as input for `a`.
* Raise ShapeError in random().
* Finish batch and repr unittests.
* Add note about mode.
* Tiny rewording.
* Change mode to _defaultval.
* Revert comment for Multinomial mode.
* Update shape check logic.
* Add DM to release notes.
* Minor docstring revisions as suggested by @AlexAndorra.
* Revise the revision.
* Add comment clarifying bounds checking in logp().
* Address review suggestions.
* Update `matches_beta_binomial` to take float precision into consideration.
* Add DM to the multivariate distributions docs.

Co-authored-by: Byron Smith <[email protected]>
Co-authored-by: Colin <[email protected]>
1 parent 1769258 commit 2a3d9a3

File tree

6 files changed: +390 −1 lines changed


Diff for: RELEASE-NOTES.md

+1 line changed

@@ -22,6 +22,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
 - Add `logcdf` method to all univariate discrete distributions (see [#4387](https://github.com/pymc-devs/pymc3/pull/4387)).
 - Add `random` method to `MvGaussianRandomWalk` (see [#4388](https://github.com/pymc-devs/pymc3/pull/4388))
 - `AsymmetricLaplace` distribution added (see [#4392](https://github.com/pymc-devs/pymc3/pull/4392)).
+- `DirichletMultinomial` distribution added (see [#4373](https://github.com/pymc-devs/pymc3/pull/4373)).

 ### Maintenance
 - Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)

Diff for: docs/source/api/distributions/multivariate.rst

+1 line changed

@@ -14,6 +14,7 @@ Multivariate
    LKJCorr
    Multinomial
    Dirichlet
+   DirichletMultinomial

 .. automodule:: pymc3.distributions.multivariate
    :members:

Diff for: pymc3/distributions/__init__.py

+2 lines changed

@@ -81,6 +81,7 @@
 from pymc3.distributions.mixture import Mixture, MixtureSameFamily, NormalMixture
 from pymc3.distributions.multivariate import (
     Dirichlet,
+    DirichletMultinomial,
     KroneckerNormal,
     LKJCholeskyCov,
     LKJCorr,
@@ -155,6 +156,7 @@
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
+    "DirichletMultinomial",
     "Wishart",
     "WishartBartlett",
     "LKJCholeskyCov",

Diff for: pymc3/distributions/multivariate.py

+157 −1 lines changed
@@ -42,15 +42,17 @@
 )
 from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
 from pymc3.distributions.special import gammaln, multigammaln
+from pymc3.exceptions import ShapeError
 from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
 from pymc3.model import Deterministic
-from pymc3.theanof import floatX
+from pymc3.theanof import floatX, intX

 __all__ = [
     "MvNormal",
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
+    "DirichletMultinomial",
     "Wishart",
     "WishartBartlett",
     "LKJCorr",
@@ -690,6 +692,160 @@ def logp(self, x):
     )


+class DirichletMultinomial(Discrete):
+    R"""Dirichlet Multinomial log-likelihood.
+
+    Dirichlet mixture of Multinomials distribution, with a marginalized PMF.
+
+    .. math::
+
+        f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)}
+                              {\Gamma(n + \sum a_k)}
+                         \prod_{k=1}^K
+                         \frac{\Gamma(x_k + a_k)}
+                              {\Gamma(x_k + 1)\Gamma(a_k)}
+
+    ==========  ===========================================
+    Support     :math:`x \in \{0, 1, \ldots, n\}` such that
+                :math:`\sum x_i = n`
+    Mean        :math:`n \frac{a_i}{\sum{a_k}}`
+    ==========  ===========================================
+
+    Parameters
+    ----------
+    n : int or array
+        Total counts in each replicate. If n is an array its shape must be (N,)
+        with N = a.shape[0].
+
+    a : one- or two-dimensional array
+        Dirichlet parameter. Elements must be strictly positive.
+        The number of categories is given by the length of the last axis.
+
+    shape : integer tuple
+        Describes shape of distribution. For example if n=array([5, 10]), and
+        a=array([1, 1, 1]), shape should be (2, 3).
+    """
+
+    def __init__(self, n, a, shape, *args, **kwargs):
+        super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs)
+
+        n = intX(n)
+        a = floatX(a)
+        if len(self.shape) > 1:
+            self.n = tt.shape_padright(n)
+            self.a = tt.as_tensor_variable(a) if a.ndim > 1 else tt.shape_padleft(a)
+        else:
+            # n is a scalar, p is a 1d array
+            self.n = tt.as_tensor_variable(n)
+            self.a = tt.as_tensor_variable(a)
+
+        p = self.a / self.a.sum(-1, keepdims=True)
+
+        self.mean = self.n * p
+        # Mode is only an approximation. Exact computation requires a complex
+        # iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013
+        mode = tt.cast(tt.round(self.mean), "int32")
+        diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
+        inc_bool_arr = tt.abs_(diff) > 0
+        mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
+        self._defaultval = mode
+
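The mode heuristic above rounds the mean and then pushes the rounding deficit back into the counts so they still sum to `n`. A minimal NumPy sketch of that idea (the helper name `approx_dm_mode` is ours, not part of pymc3; the real code absorbs the deficit with `tt.inc_subtensor`):

```python
import numpy as np


def approx_dm_mode(n, a):
    """Approximate Dirichlet-Multinomial mode: round the mean n * a / sum(a),
    then absorb any rounding deficit into the first category so the counts
    still sum to n. (The exact mode needs an iterative algorithm.)"""
    a = np.asarray(a, dtype=float)
    mode = np.round(n * a / a.sum()).astype(int)
    mode[0] += n - mode.sum()  # fix up the total count
    return mode
```

For example, `approx_dm_mode(7, [1, 1, 1])` rounds each mean of 7/3 down to 2 (total 6) and then bumps the first category, giving `[3, 2, 2]`, which sums to 7 as required.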
+    def _random(self, n, a, size=None):
+        # numpy will cast dirichlet and multinomial samples to float64 by default
+        original_dtype = a.dtype
+
+        # Thanks to the default shape handling done in generate_values, the last
+        # axis of n is a dummy axis that allows it to broadcast well with `a`
+        n = np.broadcast_to(n, size)
+        a = np.broadcast_to(a, size)
+        n = n[..., 0]
+
+        # np.random.multinomial needs `n` to be a scalar int and `a` a
+        # sequence so we semi flatten them and iterate over them
+        n_ = n.reshape([-1])
+        a_ = a.reshape([-1, a.shape[-1]])
+        p_ = np.array([np.random.dirichlet(aa) for aa in a_])
+        samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
+        samples = samples.reshape(a.shape)
+
+        # We cast back to the original dtype
+        return samples.astype(original_dtype)
+
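`_random` draws from the compound construction: a Dirichlet draw gives category probabilities, which then parameterize a Multinomial draw. A self-contained sketch of that two-stage scheme (using NumPy's newer `default_rng` API instead of the module-level `np.random` functions the pymc3 code uses; `dm_draw` is a hypothetical name):

```python
import numpy as np


def dm_draw(n, a, seed=None):
    """One Dirichlet-Multinomial sample via the two-stage compound:
    p ~ Dirichlet(a), then x ~ Multinomial(n, p)."""
    rng = np.random.default_rng(seed)
    p = rng.dirichlet(a)          # category probabilities on the simplex
    return rng.multinomial(n, p)  # integer counts summing to n


x = dm_draw(50, [0.5, 0.5, 2.0], seed=42)
assert x.sum() == 50 and x.shape == (3,)
```

Because the Dirichlet draw is marginalized out in `logp`, sampling this way agrees with the closed-form PMF without ever storing `p`.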
+    def random(self, point=None, size=None):
+        """
+        Draw random values from Dirichlet-Multinomial distribution.
+
+        Parameters
+        ----------
+        point: dict, optional
+            Dict of variable values on which random values are to be
+            conditioned (uses default point if not specified).
+        size: int, optional
+            Desired size of random sample (returns one sample if not
+            specified).
+
+        Returns
+        -------
+        array
+        """
+        n, a = draw_values([self.n, self.a], point=point, size=size)
+        samples = generate_samples(
+            self._random,
+            n,
+            a,
+            dist_shape=self.shape,
+            size=size,
+        )
+
+        # If distribution is initialized with .dist(), valid init shape is not asserted.
+        # Under normal use in a model context valid init shape is asserted at start.
+        expected_shape = to_tuple(size) + to_tuple(self.shape)
+        sample_shape = tuple(samples.shape)
+        if sample_shape != expected_shape:
+            raise ShapeError(
+                f"Expected sample shape was {expected_shape} but got {sample_shape}. "
+                "This may reflect an invalid initialization shape."
+            )
+
+        return samples
+
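The shape check in `random()` compares the sample shape against `to_tuple(size) + to_tuple(self.shape)`. A small sketch of that rule using a minimal stand-in for `to_tuple` (the real helper lives in `pymc3.distributions.shape_utils`; this simplified version only covers the cases used here):

```python
def to_tuple(x):
    """Minimal stand-in for pymc3.distributions.shape_utils.to_tuple:
    None -> (), scalar -> (x,), iterable -> tuple(x)."""
    if x is None:
        return ()
    try:
        return tuple(x)
    except TypeError:
        return (x,)


# Requesting size=5 from a distribution declared with shape=(2, 3)
# should yield samples of shape (5, 2, 3); anything else raises ShapeError.
expected = to_tuple(5) + to_tuple((2, 3))
assert expected == (5, 2, 3)
```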
+    def logp(self, value):
+        """
+        Calculate log-probability of DirichletMultinomial distribution
+        at specified value.
+
+        Parameters
+        ----------
+        value: integer array
+            Value for which log-probability is calculated.
+
+        Returns
+        -------
+        TensorVariable
+        """
+        a = self.a
+        n = self.n
+        sum_a = a.sum(axis=-1, keepdims=True)
+
+        const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
+        series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a))
+        result = const + series.sum(axis=-1, keepdims=True)
+        # Bounds checking to confirm parameters and data meet all constraints
+        # and that each observation value_i sums to n_i.
+        return bound(
+            result,
+            tt.all(tt.ge(value, 0)),
+            tt.all(tt.gt(a, 0)),
+            tt.all(tt.ge(n, 0)),
+            tt.all(tt.eq(value.sum(axis=-1, keepdims=True), n)),
+            broadcast_conditions=False,
+        )
+
+    def _distr_parameters_for_repr(self):
+        return ["n", "a"]
+
+
 def posdef(AA):
     try:
         linalg.cholesky(AA)
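The `gammaln` expression in `logp` can be cross-checked outside theano: with K = 2 categories, the Dirichlet-Multinomial reduces to the Beta-Binomial, which is exactly the equivalence one of the tests added in this PR exploits. A NumPy/SciPy sketch of the same formula (assumes SciPy is available; `dm_logpmf` is our name, not pymc3 API):

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import betabinom


def dm_logpmf(x, n, a):
    """Dirichlet-Multinomial log-PMF, mirroring the gammaln-based
    expression in logp() (single observation, 1d x and a)."""
    x = np.asarray(x, dtype=float)
    a = np.asarray(a, dtype=float)
    const = gammaln(n + 1) + gammaln(a.sum()) - gammaln(n + a.sum())
    series = gammaln(x + a) - gammaln(x + 1) - gammaln(a)
    return const + series.sum()


# With two categories, DM([k, n - k]; n, [a, b]) == BetaBinomial(k; n, a, b).
lp = dm_logpmf([3, 7], 10, [1.5, 2.5])
assert np.isclose(lp, betabinom.logpmf(3, 10, 1.5, 2.5))
```

The PMF also normalizes: summing `exp(dm_logpmf([k, n - k], n, a))` over k from 0 to n gives 1, which is a quick sanity check on the constant term.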
