Skip to content

Commit 7f307b9

Browse files
authored
Cap the values that Beta.random can generate. (#3924)
* TST: Added a test for the failing sampler
* BUG: Make Beta.random use the clipped random number generator
* DOC: Added to release notes
* Fix the PR link in the release notes
* FIX: Use scipy.stats and change ref_rand to point to clipped_beta_rvs
* FIX: Use np.maximum and np.minimum to work with scalars and arrays
1 parent e47b98a commit 7f307b9

File tree

5 files changed

+62
-3
lines changed

5 files changed

+62
-3
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
- End of sampling report now uses `arviz.InferenceData` internally and avoids storing
2626
pointwise log likelihood (see [#3883](https://github.com/pymc-devs/pymc3/pull/3883)).
2727
- The multiprocessing start method on MacOS is now set to "forkserver", to avoid crashes (see issue [#3849](https://github.com/pymc-devs/pymc3/issues/3849), solved by [#3919](https://github.com/pymc-devs/pymc3/pull/3919)).
28+
- Forced the `Beta` distribution's `random` method to generate samples that are in the open interval $(0, 1)$, i.e. no value can be equal to zero or equal to one (issue [#3898](https://github.com/pymc-devs/pymc3/issues/3898) fixed by [#3924](https://github.com/pymc-devs/pymc3/pull/3924)).
2829

2930
### Deprecations
3031
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.

Diff for: pymc3/distributions/continuous.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .dist_math import (
3333
alltrue_elemwise, betaln, bound, gammaln, i0e, incomplete_beta, logpow,
3434
normal_lccdf, normal_lcdf, SplineWrapper, std_cdf, zvalue,
35+
clipped_beta_rvs,
3536
)
3637
from .distribution import (Continuous, draw_values, generate_samples)
3738

@@ -1290,7 +1291,7 @@ def random(self, point=None, size=None):
12901291
"""
12911292
alpha, beta = draw_values([self.alpha, self.beta],
12921293
point=point, size=size)
1293-
return generate_samples(stats.beta.rvs, alpha, beta,
1294+
return generate_samples(clipped_beta_rvs, alpha, beta,
12941295
dist_shape=self.shape,
12951296
size=size)
12961297

Diff for: pymc3/distributions/dist_math.py

+46
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'''
2020
import numpy as np
2121
import scipy.linalg
22+
import scipy.stats
2223
import theano.tensor as tt
2324
import theano
2425
from theano.scalar import UnaryScalarOp, upgrade_to_float_no_complex
@@ -33,6 +34,11 @@
3334

3435
f = floatX
3536
c = - .5 * np.log(2. * np.pi)
37+
# Per-precision ``(lower, upper)`` clip bounds for beta samples: the smallest
# representable value above 0 and the largest representable value below 1.
# Keys are dtype names. ``float128`` is not exposed by numpy on every
# platform, so only the dtypes that numpy actually provides are included;
# evaluating ``np.nextafter`` with an unknown dtype name would otherwise raise
# while this module is being imported.
_beta_clip_values = {
    dtype: (np.nextafter(0, 1, dtype=dtype), np.nextafter(1, 0, dtype=dtype))
    for dtype in ["float16", "float32", "float64", "float128"]
    if hasattr(np, dtype)
}
41+
3642

3743

3844
def bound(logp, *conditions, **kwargs):
@@ -548,3 +554,43 @@ def incomplete_beta(a, b, value):
548554
tt.and_(tt.le(b * value, one), tt.le(value, 0.95)),
549555
ps,
550556
t))
557+
558+
559+
def clipped_beta_rvs(a, b, size=None, dtype="float64"):
    """Draw beta distributed random samples in the open :math:`(0, 1)` interval.

    The samples are generated with ``scipy.stats.beta.rvs`` and cast to
    ``dtype``. Any value that is equal to 0 or 1 after the cast is shifted to
    the nearest floating point number inside the open interval
    :math:`(0, 1)` that is representable with the precision given by
    ``dtype``.

    Parameters
    ----------
    a : float or array_like of floats
        Alpha, strictly positive (>0).
    b : float or array_like of floats
        Beta, strictly positive (>0).
    size : int or tuple of ints, optional
        Output shape. If the given shape is, e.g., ``(m, n, k)``, then
        ``m * n * k`` samples are drawn. If size is ``None`` (default),
        a single value is returned if ``a`` and ``b`` are both scalars.
        Otherwise, ``np.broadcast(a, b).size`` samples are drawn.
    dtype : str or dtype instance
        The floating point precision that the samples should have. Anything
        that ``np.dtype`` resolves to a floating point type is accepted
        (e.g. ``"float32"``, ``np.float32`` or ``np.dtype("float32")``). This
        also determines the values that will be used to shift any samples
        that are zero or one.

    Returns
    -------
    out : ndarray or scalar
        Drawn samples from the parameterized beta distribution. The scipy
        implementation can yield values that are equal to zero or one. We
        assume the support of the Beta distribution to be in the open interval
        :math:`(0, 1)`, so we shift any sample that is equal to 0 to
        ``np.nextafter(0, 1, dtype=dtype)`` and any sample that is equal to 1
        is shifted to ``np.nextafter(1, 0, dtype=dtype)``.

    Raises
    ------
    TypeError
        If ``dtype`` is not a floating point dtype.

    """
    # Normalise so that strings, scalar types and dtype instances all behave
    # the same way.
    dtype = np.dtype(dtype)
    if dtype.kind != "f":
        raise TypeError(
            "clipped_beta_rvs requires a floating point dtype, got {}".format(dtype)
        )
    # Compute the clip bounds in the target precision: the closest
    # representable neighbours of 0 and 1 that lie inside (0, 1).
    zero = np.zeros((), dtype=dtype)
    one = np.ones((), dtype=dtype)
    lower = np.nextafter(zero, one)
    upper = np.nextafter(one, zero)
    # ``np.asarray`` guards against ``rvs`` handing back a bare scalar when
    # ``size`` is None. The cast happens *before* clipping because rounding to
    # a narrower dtype can itself produce exact zeros or ones.
    out = np.asarray(scipy.stats.beta.rvs(a, b, size=size)).astype(dtype)
    return np.maximum(np.minimum(out, upper), lower)

Diff for: pymc3/tests/test_dist_math.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from ..theanof import floatX
2525
from ..distributions import Discrete
2626
from ..distributions.dist_math import (
27-
bound, factln, alltrue_scalar, MvNormalLogp, SplineWrapper, i0e)
27+
bound, factln, alltrue_scalar, MvNormalLogp, SplineWrapper, i0e,
28+
clipped_beta_rvs,
29+
)
2830

2931

3032
def test_bound():
@@ -216,3 +218,11 @@ def test_grad(self):
216218
utt.verify_grad(i0e, [-2.])
217219
utt.verify_grad(i0e, [[0.5, -2.]])
218220
utt.verify_grad(i0e, [[[0.5, -2.]]])
221+
222+
223+
@pytest.mark.parametrize(
    "dtype",
    [
        "float16",
        "float32",
        "float64",
        # Extended precision is not exposed by numpy on every platform, so
        # skip that case instead of failing where the type does not exist.
        pytest.param(
            "float128",
            marks=pytest.mark.skipif(
                not hasattr(np, "float128"),
                reason="numpy does not provide float128 on this platform",
            ),
        ),
    ],
)
def test_clipped_beta_rvs(dtype):
    """Check that clipped beta samples never hit the interval end points.

    Draws a large number of samples with very small shape parameters, for
    which most of the probability mass sits next to 0 and 1, and verifies
    that no sample is exactly zero or exactly one (issue #3898).
    """
    values = clipped_beta_rvs(0.01, 0.01, size=1000000, dtype=dtype)
    assert not (np.any(values == 0) or np.any(values == 1))

Diff for: pymc3/tests/test_distributions_random.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import theano
2323

2424
import pymc3 as pm
25+
from pymc3.distributions.dist_math import clipped_beta_rvs
2526
from pymc3.distributions.distribution import (draw_values,
2627
_DrawValuesContext,
2728
_DrawValuesContextBlocker)
@@ -548,7 +549,7 @@ def ref_rand(size, mu, lam, alpha):
548549

549550
def test_beta(self):
550551
def ref_rand(size, alpha, beta):
551-
return st.beta.rvs(a=alpha, b=beta, size=size)
552+
return clipped_beta_rvs(a=alpha, b=beta, size=size)
552553
pymc3_random(pm.Beta, {'alpha': Rplus, 'beta': Rplus}, ref_rand=ref_rand)
553554

554555
def test_exponential(self):

0 commit comments

Comments
 (0)