Skip to content

Commit 854fd04

Browse files
zoj613ricardoV94
authored andcommitted
Add DiscreteWeibull moment
1 parent d8bf4ba commit 854fd04

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pymc/distributions/discrete.py

+6
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ def dist(cls, q, beta, *args, **kwargs):
490490
beta = at.as_tensor_variable(floatX(beta))
491491
return super().dist([q, beta], **kwargs)
492492

493+
def get_moment(rv, size, q, beta):
494+
median = at.power(at.log(0.5) / at.log(q), 1 / beta) - 1
495+
if not rv_size_is_none(size):
496+
median = at.full(size, median)
497+
return median
498+
493499
def logp(value, q, beta):
494500
r"""
495501
Calculate log-probability of DiscreteWeibull distribution at specified value.

pymc/tests/test_distributions_moments.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DensityDist,
2323
Dirichlet,
2424
DiscreteUniform,
25+
DiscreteWeibull,
2526
ExGaussian,
2627
Exponential,
2728
Flat,
@@ -110,7 +111,6 @@ def test_all_distributions_have_moments():
110111

111112
# Distributions that have been refactored but don't yet have moments
112113
not_implemented |= {
113-
dist_module.discrete.DiscreteWeibull,
114114
dist_module.multivariate.DirichletMultinomial,
115115
dist_module.multivariate.Wishart,
116116
}
@@ -752,6 +752,26 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
752752
assert_moment_is_expected(model, expected)
753753

754754

755+
@pytest.mark.parametrize(
756+
"q, beta, size, expected",
757+
[
758+
(0.5, 0.5, None, 0),
759+
(0.6, 0.1, 5, (20,) * 5),
760+
(np.linspace(0.25, 0.99, 4), 0.42, None, [0, 0, 6, 23862]),
761+
(
762+
np.linspace(0.5, 0.99, 3),
763+
[[1, 1.25, 1.75], [1.25, 0.75, 0.5]],
764+
None,
765+
[[0, 0, 10], [0, 2, 4755]],
766+
),
767+
],
768+
)
769+
def test_discrete_weibull_moment(q, beta, size, expected):
770+
with Model() as model:
771+
DiscreteWeibull("x", q=q, beta=beta, size=size)
772+
assert_moment_is_expected(model, expected)
773+
774+
755775
@pytest.mark.parametrize(
756776
"a, size, expected",
757777
[

0 commit comments

Comments
 (0)