42 | 42 | )
43 | 43 | from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
44 | 44 | from pymc3.distributions.special import gammaln, multigammaln
| 45 | +from pymc3.exceptions import ShapeError |
45 | 46 | from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker
46 | 47 | from pymc3.model import Deterministic
47 | | -from pymc3.theanof import floatX |
| 48 | +from pymc3.theanof import floatX, intX |
48 | 49 |
49 | 50 | __all__ = [
50 | 51 | "MvNormal",
51 | 52 | "MvStudentT",
52 | 53 | "Dirichlet",
53 | 54 | "Multinomial",
| 55 | + "DirichletMultinomial", |
54 | 56 | "Wishart",
55 | 57 | "WishartBartlett",
56 | 58 | "LKJCorr",
@@ -690,6 +692,160 @@ def logp(self, x):
690 | 692 | )
691 | 693 |
692 | 694 |
| 695 | +class DirichletMultinomial(Discrete): |
| 696 | + R"""Dirichlet Multinomial log-likelihood. |
| 697 | + |
| 698 | + Dirichlet mixture of Multinomials distribution, with a marginalized PMF. |
| 699 | + |
| 700 | + .. math:: |
| 701 | + |
| 702 | + f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)} |
| 703 | + {\Gamma(n + \sum a_k)} |
| 704 | + \prod_{k=1}^K |
| 705 | + \frac{\Gamma(x_k + a_k)} |
| 706 | + {\Gamma(x_k + 1)\Gamma(a_k)} |
| 707 | + |
| 708 | + ========== =========================================== |
| 709 | + Support :math:`x \in \{0, 1, \ldots, n\}` such that |
| 710 | + :math:`\sum x_i = n` |
| 711 | + Mean :math:`n \frac{a_i}{\sum{a_k}}` |
| 712 | + ========== =========================================== |
| 713 | + |
| 714 | + Parameters |
| 715 | + ---------- |
| 716 | + n : int or array |
| 717 | + Total counts in each replicate. If n is an array, its shape must be |
| 718 | + (N,), where N = a.shape[0]. |
| 719 | + |
| 720 | + a : one- or two-dimensional array |
| 721 | + Dirichlet parameter. Elements must be strictly positive. |
| 722 | + The number of categories is given by the length of the last axis. |
| 723 | + |
| 724 | + shape : integer tuple |
| 725 | + Describes shape of distribution. For example, if n=array([5, 10]) |
| 726 | + and a=array([1, 1, 1]), shape should be (2, 3). |
| 727 | + """ |
| 728 | + |
| 729 | + def __init__(self, n, a, shape, *args, **kwargs): |
| 730 | + |
| 731 | + super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs) |
| 732 | + |
| 733 | + n = intX(n) |
| 734 | + a = floatX(a) |
| 735 | + if len(self.shape) > 1: |
| 736 | + self.n = tt.shape_padright(n) |
| 737 | + self.a = tt.as_tensor_variable(a) if a.ndim > 1 else tt.shape_padleft(a) |
| 738 | + else: |
| 739 | + # n is a scalar, a is a 1d array |
| 740 | + self.n = tt.as_tensor_variable(n) |
| 741 | + self.a = tt.as_tensor_variable(a) |
| 742 | + |
| 743 | + p = self.a / self.a.sum(-1, keepdims=True) |
| 744 | + |
| 745 | + self.mean = self.n * p |
| 746 | + # Mode is only an approximation. Exact computation requires a complex |
| 747 | + # iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013 |
| 748 | + mode = tt.cast(tt.round(self.mean), "int32") |
| 749 | + diff = self.n - tt.sum(mode, axis=-1, keepdims=True) |
| 750 | + inc_bool_arr = tt.abs_(diff) > 0 |
| 751 | + mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()]) |
| 752 | + self._defaultval = mode |
| 753 | + |
| 754 | + def _random(self, n, a, size=None): |
| 755 | + # numpy will cast dirichlet and multinomial samples to float64 by default |
| 756 | + original_dtype = a.dtype |
| 757 | + |
| 758 | + # Thanks to the default shape handling done in generate_samples, the last |
| 759 | + # axis of n is a dummy axis that allows it to broadcast well with `a` |
| 760 | + n = np.broadcast_to(n, size) |
| 761 | + a = np.broadcast_to(a, size) |
| 762 | + n = n[..., 0] |
| 763 | + |
| 764 | + # np.random.multinomial needs `n` to be a scalar int and `a` a |
| 765 | + # sequence, so we semi-flatten them and iterate over them |
| 766 | + n_ = n.reshape([-1]) |
| 767 | + a_ = a.reshape([-1, a.shape[-1]]) |
| 768 | + p_ = np.array([np.random.dirichlet(aa) for aa in a_]) |
| 769 | + samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)]) |
| 770 | + samples = samples.reshape(a.shape) |
| 771 | + |
| 772 | + # We cast back to the original dtype |
| 773 | + return samples.astype(original_dtype) |
| 774 | + |
| 775 | + def random(self, point=None, size=None): |
| 776 | + """ |
| 777 | + Draw random values from the Dirichlet-Multinomial distribution. |
| 778 | + |
| 779 | + Parameters |
| 780 | + ---------- |
| 781 | + point: dict, optional |
| 782 | + Dict of variable values on which random values are to be |
| 783 | + conditioned (uses default point if not specified). |
| 784 | + size: int, optional |
| 785 | + Desired size of random sample (returns one sample if not |
| 786 | + specified). |
| 787 | + |
| 788 | + Returns |
| 789 | + ------- |
| 790 | + array |
| 791 | + """ |
| 792 | + n, a = draw_values([self.n, self.a], point=point, size=size) |
| 793 | + samples = generate_samples( |
| 794 | + self._random, |
| 795 | + n, |
| 796 | + a, |
| 797 | + dist_shape=self.shape, |
| 798 | + size=size, |
| 799 | + ) |
| 800 | + |
| 801 | + # If the distribution is initialized with .dist(), a valid init shape is not |
| 802 | + # asserted; under normal use in a model context it is asserted at start. |
| 803 | + expected_shape = to_tuple(size) + to_tuple(self.shape) |
| 804 | + sample_shape = tuple(samples.shape) |
| 805 | + if sample_shape != expected_shape: |
| 806 | + raise ShapeError( |
| 807 | + f"Expected sample shape was {expected_shape} but got {sample_shape}. " |
| 808 | + "This may reflect an invalid initialization shape." |
| 809 | + ) |
| 810 | + |
| 811 | + return samples |
| 812 | + |
| 813 | + def logp(self, value): |
| 814 | + """ |
| 815 | + Calculate log-probability of DirichletMultinomial distribution |
| 816 | + at specified value. |
| 817 | + |
| 818 | + Parameters |
| 819 | + ---------- |
| 820 | + value: integer array |
| 821 | + Value for which log-probability is calculated. |
| 822 | + |
| 823 | + Returns |
| 824 | + ------- |
| 825 | + TensorVariable |
| 826 | + """ |
| 827 | + a = self.a |
| 828 | + n = self.n |
| 829 | + sum_a = a.sum(axis=-1, keepdims=True) |
| 830 | + |
| 831 | + const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a) |
| 832 | + series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a)) |
| 833 | + result = const + series.sum(axis=-1, keepdims=True) |
| 834 | + # Bounds checking to confirm parameters and data meet all constraints |
| 835 | + # and that each observation value_i sums to n_i. |
| 836 | + return bound( |
| 837 | + result, |
| 838 | + tt.all(tt.ge(value, 0)), |
| 839 | + tt.all(tt.gt(a, 0)), |
| 840 | + tt.all(tt.ge(n, 0)), |
| 841 | + tt.all(tt.eq(value.sum(axis=-1, keepdims=True), n)), |
| 842 | + broadcast_conditions=False, |
| 843 | + ) |
| 844 | + |
| 845 | + def _distr_parameters_for_repr(self): |
| 846 | + return ["n", "a"] |
| 847 | + |
| 848 | + |
693 | 849 | def posdef(AA):
694 | 850 | try:
695 | 851 | linalg.cholesky(AA)
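
For orientation, a minimal usage sketch of the distribution this diff adds. The data and the `frac` prior are illustrative, assuming the pymc3 API of this branch:

```python
import numpy as np
import pymc3 as pm

# Two replicates of 10 trials over 3 categories (made-up counts).
counts = np.array([[4, 3, 3],
                   [6, 2, 2]])

with pm.Model():
    # Positive concentration parameters for the Dirichlet prior.
    frac = pm.Lognormal("frac", mu=0.0, sigma=1.0, shape=3)
    # n holds the per-replicate totals; shape matches the observed counts.
    pm.DirichletMultinomial(
        "dm", n=counts.sum(axis=-1), a=frac, shape=counts.shape, observed=counts
    )
    trace = pm.sample()
```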
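The `_defaultval` logic in `__init__` rounds the mean and then pushes the leftover counts into the first category, so each row still sums to n. A standalone NumPy re-implementation of that approximation (`approx_mode` is a hypothetical helper, not part of the diff):

```python
import numpy as np

def approx_mode(n, a):
    # Round the mean n * a / sum(a), then add the rounding residual to the
    # first category of each row whose total drifted away from n.
    n = np.asarray(n)[..., None]                 # pad right, like tt.shape_padright
    p = a / a.sum(axis=-1, keepdims=True)
    mode = np.round(n * p).astype("int32")
    diff = (n - mode.sum(axis=-1, keepdims=True)).astype("int32")
    off = np.abs(diff) > 0
    mode[off.nonzero()] += diff[off.nonzero()]
    return mode

print(approx_mode(np.array([5, 10]), np.ones(3)))  # rows sum to 5 and 10
```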
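And a quick numerical check of the marginalized PMF used in `logp`: with a = (1, 1, 1), the Dirichlet-multinomial is uniform over the C(12, 2) = 66 compositions of n = 10 into three counts. `dm_logpmf` below is a hypothetical NumPy mirror of the expression in the diff:

```python
import numpy as np
from scipy.special import gammaln

def dm_logpmf(x, n, a):
    # Constant term in n and sum(a), plus a per-category series summed
    # over the last axis, exactly as in the diff's logp.
    x, a = np.asarray(x, dtype=float), np.asarray(a, dtype=float)
    sum_a = a.sum(axis=-1, keepdims=True)
    const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
    series = gammaln(x + a) - gammaln(x + 1) - gammaln(a)
    return np.squeeze(const + series.sum(axis=-1, keepdims=True), axis=-1)

print(dm_logpmf([4, 3, 3], n=10, a=np.ones(3)))  # -4.1897
print(-np.log(66.0))                             # -4.1897
```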