Skip to content

Commit b71bb74

Browse files
Add StickBreakingWeights distribution (#5200)
1 parent d52655d commit b71bb74

File tree

6 files changed

+289
-1
lines changed

6 files changed

+289
-1
lines changed

docs/source/api/distributions/multivariate.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ Multivariate
1919
MatrixNormal
2020
KroneckerNormal
2121
CAR
22+
StickBreakingWeights

pymc/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
MvNormal,
9696
MvStudentT,
9797
OrderedMultinomial,
98+
StickBreakingWeights,
9899
Wishart,
99100
WishartBartlett,
100101
)
@@ -159,6 +160,7 @@
159160
"KroneckerNormal",
160161
"MvStudentT",
161162
"Dirichlet",
163+
"StickBreakingWeights",
162164
"Multinomial",
163165
"DirichletMultinomial",
164166
"OrderedMultinomial",

pymc/distributions/multivariate.py

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@
4444
from pymc.aesaraf import floatX, intX
4545
from pymc.distributions import transforms
4646
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
47-
from pymc.distributions.dist_math import check_parameters, factln, logpow, multigammaln
47+
from pymc.distributions.dist_math import (
48+
betaln,
49+
check_parameters,
50+
factln,
51+
logpow,
52+
multigammaln,
53+
)
4854
from pymc.distributions.distribution import Continuous, Discrete
4955
from pymc.distributions.shape_utils import (
5056
broadcast_dist_samples_to,
@@ -67,6 +73,7 @@
6773
"MatrixNormal",
6874
"KroneckerNormal",
6975
"CAR",
76+
"StickBreakingWeights",
7077
]
7178

7279
# Step methods and advi do not catch LinAlgErrors at the
@@ -2167,3 +2174,181 @@ def logp(value, mu, W, alpha, tau):
21672174
tau > 0,
21682175
msg="-1 <= alpha <= 1, tau > 0",
21692176
)
2177+
2178+
2179+
class StickBreakingWeightsRV(RandomVariable):
2180+
name = "stick_breaking_weights"
2181+
ndim_supp = 1
2182+
ndims_params = [0, 0]
2183+
dtype = "floatX"
2184+
_print_name = ("StickBreakingWeights", "\\operatorname{StickBreakingWeights}")
2185+
2186+
def make_node(self, rng, size, dtype, alpha, K):
2187+
2188+
alpha = at.as_tensor_variable(alpha)
2189+
K = at.as_tensor_variable(intX(K))
2190+
2191+
if alpha.ndim > 0:
2192+
raise ValueError("The concentration parameter needs to be a scalar.")
2193+
2194+
if K.ndim > 0:
2195+
raise ValueError("K must be a scalar.")
2196+
2197+
return super().make_node(rng, size, dtype, alpha, K)
2198+
2199+
def _infer_shape(self, size, dist_params, param_shapes=None):
2200+
alpha, K = dist_params
2201+
2202+
size = tuple(size)
2203+
2204+
return size + (K + 1,)
2205+
2206+
@classmethod
2207+
def rng_fn(cls, rng, alpha, K, size):
2208+
if K < 0:
2209+
raise ValueError("K needs to be positive.")
2210+
2211+
if size is None:
2212+
size = (K,)
2213+
elif isinstance(size, int):
2214+
size = (size,) + (K,)
2215+
else:
2216+
size = tuple(size) + (K,)
2217+
2218+
betas = rng.beta(1, alpha, size=size)
2219+
2220+
sticks = np.concatenate(
2221+
(
2222+
np.ones(shape=(size[:-1] + (1,))),
2223+
np.cumprod(1 - betas[..., :-1], axis=-1),
2224+
),
2225+
axis=-1,
2226+
)
2227+
2228+
weights = sticks * betas
2229+
weights = np.concatenate(
2230+
(weights, 1 - weights.sum(axis=-1)[..., np.newaxis]),
2231+
axis=-1,
2232+
)
2233+
2234+
return weights
2235+
2236+
2237+
stickbreakingweights = StickBreakingWeightsRV()
2238+
2239+
2240+
class StickBreakingWeights(Continuous):
2241+
r"""
2242+
Likelihood of truncated stick-breaking weights. The weights are generated from a
2243+
stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
2244+
:math:`k \in \{1, \ldots, K\}` and :math:`x_K = \prod_{\ell = 1}^{K} (1 - v_\ell) = 1 - \sum_{\ell=1}^K x_\ell`
2245+
with :math:`v_k \stackrel{\text{i.i.d.}}{\sim} \text{Beta}(1, \alpha)`.
2246+
2247+
.. math:
2248+
2249+
f(\mathbf{x}|\alpha, K) =
2250+
B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
2251+
2252+
======== ===============================================
2253+
Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
2254+
such that :math:`\sum x_k = 1`
2255+
Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
2256+
for :math:`k \in \{1, \ldots, K\}` and :math:`\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
2257+
======== ===============================================
2258+
2259+
Parameters
2260+
----------
2261+
alpha: float
2262+
Concentration parameter (alpha > 0).
2263+
K: int
2264+
The number of "sticks" to break off from an initial one-unit stick. The length of the weight
2265+
vector is K + 1, where the last weight is one minus the sum of all the first sticks.
2266+
2267+
References
2268+
----------
2269+
.. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
2270+
Journal of the American Statistical Association, 96(453), 161-173.
2271+
2272+
.. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
2273+
analysis. New York: Springer.
2274+
"""
2275+
rv_op = stickbreakingweights
2276+
2277+
def __new__(cls, name, *args, **kwargs):
2278+
kwargs.setdefault("transform", transforms.simplex)
2279+
return super().__new__(cls, name, *args, **kwargs)
2280+
2281+
@classmethod
2282+
def dist(cls, alpha, K, *args, **kwargs):
2283+
alpha = at.as_tensor_variable(floatX(alpha))
2284+
K = at.as_tensor_variable(intX(K))
2285+
2286+
assert_negative_support(alpha, "alpha", "StickBreakingWeights")
2287+
assert_negative_support(K, "K", "StickBreakingWeights")
2288+
2289+
return super().dist([alpha, K], **kwargs)
2290+
2291+
def get_moment(rv, size, alpha, K):
2292+
moment = (alpha / (1 + alpha)) ** at.arange(K)
2293+
moment *= 1 / (1 + alpha)
2294+
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
2295+
if not rv_size_is_none(size):
2296+
moment_size = at.concatenate(
2297+
[
2298+
size,
2299+
[
2300+
K + 1,
2301+
],
2302+
]
2303+
)
2304+
moment = at.full(moment_size, moment)
2305+
2306+
return moment
2307+
2308+
def logp(value, alpha, K):
2309+
"""
2310+
Calculate log-probability of the distribution induced from the stick-breaking process
2311+
at specified value.
2312+
2313+
Parameters
2314+
----------
2315+
value: numeric
2316+
Value for which log-probability is calculated.
2317+
2318+
Returns
2319+
-------
2320+
TensorVariable
2321+
"""
2322+
logp = -at.sum(
2323+
at.log(
2324+
at.cumsum(
2325+
value[..., ::-1],
2326+
axis=-1,
2327+
)
2328+
),
2329+
axis=-1,
2330+
)
2331+
logp += -K * betaln(1, alpha)
2332+
logp += alpha * at.log(value[..., -1])
2333+
2334+
logp = at.switch(
2335+
at.or_(
2336+
at.any(
2337+
at.and_(at.le(value, 0), at.ge(value, 1)),
2338+
axis=-1,
2339+
),
2340+
at.or_(
2341+
at.bitwise_not(at.allclose(value.sum(-1), 1)),
2342+
at.neq(value.shape[-1], K + 1),
2343+
),
2344+
),
2345+
-np.inf,
2346+
logp,
2347+
)
2348+
2349+
return check_parameters(
2350+
logp,
2351+
alpha > 0,
2352+
K > 0,
2353+
msg="alpha > 0, K > 0",
2354+
)

pymc/tests/test_distributions.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def polyagamma_cdf(*args, **kwargs):
109109
Poisson,
110110
Rice,
111111
SkewNormal,
112+
StickBreakingWeights,
112113
StudentT,
113114
Triangular,
114115
TruncatedNormal,
@@ -2123,6 +2124,40 @@ def test_dirichlet_invalid(self):
21232124
valid_dist = Dirichlet.dist(a=[1, 1, 1])
21242125
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
21252126

2127+
@pytest.mark.parametrize(
2128+
"value,alpha,K,logp",
2129+
[
2130+
(np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
2131+
(np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
2132+
(np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
2133+
(np.append(0.5 ** np.arange(1, 20), 0.5 ** 20), 5, 19, 94.20462772778092),
2134+
(
2135+
(np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
2136+
2.5,
2137+
3,
2138+
np.array([1.29317672, 1.50126157]),
2139+
),
2140+
],
2141+
)
2142+
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
2143+
with Model() as model:
2144+
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
2145+
pt = {"sbw": value}
2146+
assert_almost_equal(
2147+
pm.logp(sbw, value).eval(),
2148+
logp,
2149+
decimal=select_by_precision(float64=6, float32=2),
2150+
err_msg=str(pt),
2151+
)
2152+
2153+
def test_stickbreakingweights_invalid(self):
2154+
sbw = pm.StickBreakingWeights.dist(3.0, 3)
2155+
sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
2156+
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
2157+
assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
2158+
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
2159+
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
2160+
21262161
@pytest.mark.parametrize(
21272162
"a",
21282163
[

pymc/tests/test_distributions_moments.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
Rice,
5353
Simulator,
5454
SkewNormal,
55+
StickBreakingWeights,
5556
StudentT,
5657
Triangular,
5758
TruncatedNormal,
@@ -1087,6 +1088,35 @@ def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
10871088
def test_rice_moment(nu, sigma, size, expected):
10881089
with Model() as model:
10891090
Rice("x", nu=nu, sigma=sigma, size=size)
1091+
1092+
1093+
@pytest.mark.parametrize(
1094+
"alpha, K, size, expected",
1095+
[
1096+
(3, 11, None, np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11)),
1097+
(5, 19, None, np.append((5 / 6) ** np.arange(19) * 1 / 6, (5 / 6) ** 19)),
1098+
(
1099+
1,
1100+
7,
1101+
(13,),
1102+
np.full(
1103+
shape=(13, 8), fill_value=np.append((1 / 2) ** np.arange(7) * 1 / 2, (1 / 2) ** 7)
1104+
),
1105+
),
1106+
(
1107+
0.5,
1108+
5,
1109+
(3, 5, 7),
1110+
np.full(
1111+
shape=(3, 5, 7, 6),
1112+
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
1113+
),
1114+
),
1115+
],
1116+
)
1117+
def test_stickbreakingweights_moment(alpha, K, size, expected):
1118+
with Model() as model:
1119+
StickBreakingWeights("x", alpha=alpha, K=K, size=size)
10901120
assert_moment_is_expected(model, expected)
10911121

10921122

pymc/tests/test_distributions_random.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,41 @@ class TestDirichlet(BaseTestDistribution):
11761176
]
11771177

11781178

1179+
class TestStickBreakingWeights(BaseTestDistribution):
1180+
pymc_dist = pm.StickBreakingWeights
1181+
pymc_dist_params = {"alpha": 2.0, "K": 19}
1182+
expected_rv_op_params = {"alpha": 2.0, "K": 19}
1183+
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
1184+
sizes_expected = [
1185+
(20,),
1186+
(17, 20),
1187+
(
1188+
5,
1189+
20,
1190+
),
1191+
(11, 5, 20),
1192+
(3, 13, 5, 20),
1193+
]
1194+
tests_to_run = [
1195+
"check_pymc_params_match_rv_op",
1196+
"check_rv_size",
1197+
"check_basic_properties",
1198+
]
1199+
1200+
def check_basic_properties(self):
1201+
default_rng = aesara.shared(np.random.default_rng(1234))
1202+
draws = pm.StickBreakingWeights.dist(
1203+
alpha=3.5,
1204+
K=19,
1205+
size=(2, 3, 5),
1206+
rng=default_rng,
1207+
).eval()
1208+
1209+
assert np.allclose(draws.sum(-1), 1)
1210+
assert np.all(draws >= 0)
1211+
assert np.all(draws <= 1)
1212+
1213+
11791214
class TestMultinomial(BaseTestDistribution):
11801215
pymc_dist = pm.Multinomial
11811216
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}

0 commit comments

Comments
 (0)