Skip to content

Commit ed06ec1

Browse files
authored
pythonGH-81620: Add random.binomialvariate() (pythonGH-94719)
1 parent f5c02af commit ed06ec1

File tree

4 files changed

+175
-8
lines changed

4 files changed

+175
-8
lines changed

Doc/library/random.rst

+25-6
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,28 @@ Functions for sequences
258258
The *population* must be a sequence. Automatic conversion of sets
259259
to lists is no longer supported.
260260

261+
Discrete distributions
262+
----------------------
263+
264+
The following function generates a discrete distribution.
265+
266+
.. function:: binomialvariate(n=1, p=0.5)
267+
268+
`Binomial distribution
269+
<http://mathworld.wolfram.com/BinomialDistribution.html>`_.
270+
Return the number of successes for *n* independent trials with the
271+
probability of success in each trial being *p*:
272+
273+
Mathematically equivalent to::
274+
275+
sum(random() < p for i in range(n))
276+
277+
The number of trials *n* should be a non-negative integer.
278+
The probability of success *p* should be between ``0.0 <= p <= 1.0``.
279+
The result is an integer in the range ``0 <= X <= n``.
280+
281+
.. versionadded:: 3.12
282+
261283

262284
.. _real-valued-distributions:
263285

@@ -452,16 +474,13 @@ Simulations::
452474
>>> # Deal 20 cards without replacement from a deck
453475
>>> # of 52 playing cards, and determine the proportion of cards
454476
>>> # with a ten-value: ten, jack, queen, or king.
455-
>>> dealt = sample(['tens', 'low cards'], counts=[16, 36], k=20)
456-
>>> dealt.count('tens') / 20
477+
>>> deal = sample(['tens', 'low cards'], counts=[16, 36], k=20)
478+
>>> deal.count('tens') / 20
457479
0.15
458480

459481
>>> # Estimate the probability of getting 5 or more heads from 7 spins
460482
>>> # of a biased coin that settles on heads 60% of the time.
461-
>>> def trial():
462-
... return choices('HT', cum_weights=(0.60, 1.00), k=7).count('H') >= 5
463-
...
464-
>>> sum(trial() for i in range(10_000)) / 10_000
483+
>>> sum(binomialvariate(n=7, p=0.6) >= 5 for i in range(10_000)) / 10_000
465484
0.4169
466485

467486
>>> # Probability of the median of 5 samples being in middle two quartiles

Lib/random.py

+93-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
negative exponential
2525
gamma
2626
beta
27+
binomial
2728
pareto
2829
Weibull
2930
@@ -49,6 +50,7 @@
4950
from math import log as _log, exp as _exp, pi as _pi, e as _e, ceil as _ceil
5051
from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
5152
from math import tau as TWOPI, floor as _floor, isfinite as _isfinite
53+
from math import lgamma as _lgamma, fabs as _fabs
5254
from os import urandom as _urandom
5355
from _collections_abc import Sequence as _Sequence
5456
from operator import index as _index
@@ -68,6 +70,7 @@
6870
"Random",
6971
"SystemRandom",
7072
"betavariate",
73+
"binomialvariate",
7174
"choice",
7275
"choices",
7376
"expovariate",
@@ -725,6 +728,91 @@ def betavariate(self, alpha, beta):
725728
return y / (y + self.gammavariate(beta, 1.0))
726729
return 0.0
727730

731+
732+
def binomialvariate(self, n=1, p=0.5):
733+
"""Binomial random variable.
734+
735+
Gives the number of successes for *n* independent trials
736+
with the probability of success in each trial being *p*:
737+
738+
sum(random() < p for i in range(n))
739+
740+
Returns an integer in the range: 0 <= X <= n
741+
742+
"""
743+
# Error check inputs and handle edge cases
744+
if n < 0:
745+
raise ValueError("n must be non-negative")
746+
if p <= 0.0 or p >= 1.0:
747+
if p == 0.0:
748+
return 0
749+
if p == 1.0:
750+
return n
751+
raise ValueError("p must be in the range 0.0 <= p <= 1.0")
752+
753+
random = self.random
754+
755+
# Fast path for a common case
756+
if n == 1:
757+
return _index(random() < p)
758+
759+
# Exploit symmetry to establish: p <= 0.5
760+
if p > 0.5:
761+
return n - self.binomialvariate(n, 1.0 - p)
762+
763+
if n * p < 10.0:
764+
# BG: Geometric method by Devroye with running time of O(np).
765+
# https://dl.acm.org/doi/pdf/10.1145/42372.42381
766+
x = y = 0
767+
c = _log(1.0 - p)
768+
if not c:
769+
return x
770+
while True:
771+
y += _floor(_log(random()) / c) + 1
772+
if y > n:
773+
return x
774+
x += 1
775+
776+
# BTRS: Transformed rejection with squeeze method by Wolfgang Hörmann
777+
# https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.47.8407&rep=rep1&type=pdf
778+
assert n*p >= 10.0 and p <= 0.5
779+
setup_complete = False
780+
781+
spq = _sqrt(n * p * (1.0 - p)) # Standard deviation of the distribution
782+
b = 1.15 + 2.53 * spq
783+
a = -0.0873 + 0.0248 * b + 0.01 * p
784+
c = n * p + 0.5
785+
vr = 0.92 - 4.2 / b
786+
787+
while True:
788+
789+
u = random()
790+
v = random()
791+
u -= 0.5
792+
us = 0.5 - _fabs(u)
793+
k = _floor((2.0 * a / us + b) * u + c)
794+
if k < 0 or k > n:
795+
continue
796+
797+
# The early-out "squeeze" test substantially reduces
798+
# the number of acceptance condition evaluations.
799+
if us >= 0.07 and v <= vr:
800+
return k
801+
802+
# Acceptance-rejection test.
803+
# Note, the original paper errorneously omits the call to log(v)
804+
# when comparing to the log of the rescaled binomial distribution.
805+
if not setup_complete:
806+
alpha = (2.83 + 5.1 / b) * spq
807+
lpq = _log(p / (1.0 - p))
808+
m = _floor((n + 1) * p) # Mode of the distribution
809+
h = _lgamma(m + 1) + _lgamma(n - m + 1)
810+
setup_complete = True # Only needs to be done once
811+
v *= alpha / (a / (us * us) + b)
812+
if _log(v) <= h - _lgamma(k + 1) - _lgamma(n - k + 1) + (k - m) * lpq:
813+
return k
814+
815+
728816
def paretovariate(self, alpha):
729817
"""Pareto distribution. alpha is the shape parameter."""
730818
# Jain, pg. 495
@@ -810,6 +898,7 @@ def _notimplemented(self, *args, **kwds):
810898
gammavariate = _inst.gammavariate
811899
gauss = _inst.gauss
812900
betavariate = _inst.betavariate
901+
binomialvariate = _inst.binomialvariate
813902
paretovariate = _inst.paretovariate
814903
weibullvariate = _inst.weibullvariate
815904
getstate = _inst.getstate
@@ -834,15 +923,17 @@ def _test_generator(n, func, args):
834923
low = min(data)
835924
high = max(data)
836925

837-
print(f'{t1 - t0:.3f} sec, {n} times {func.__name__}')
926+
print(f'{t1 - t0:.3f} sec, {n} times {func.__name__}{args!r}')
838927
print('avg %g, stddev %g, min %g, max %g\n' % (xbar, sigma, low, high))
839928

840929

841-
def _test(N=2000):
930+
def _test(N=10_000):
842931
_test_generator(N, random, ())
843932
_test_generator(N, normalvariate, (0.0, 1.0))
844933
_test_generator(N, lognormvariate, (0.0, 1.0))
845934
_test_generator(N, vonmisesvariate, (0.0, 1.0))
935+
_test_generator(N, binomialvariate, (15, 0.60))
936+
_test_generator(N, binomialvariate, (100, 0.75))
846937
_test_generator(N, gammavariate, (0.01, 1.0))
847938
_test_generator(N, gammavariate, (0.1, 1.0))
848939
_test_generator(N, gammavariate, (0.1, 2.0))

Lib/test/test_random.py

+56
Original file line numberDiff line numberDiff line change
@@ -1045,13 +1045,69 @@ def test_constant(self):
10451045
(g.lognormvariate, (0.0, 0.0), 1.0),
10461046
(g.lognormvariate, (-float('inf'), 0.0), 0.0),
10471047
(g.normalvariate, (10.0, 0.0), 10.0),
1048+
(g.binomialvariate, (0, 0.5), 0),
1049+
(g.binomialvariate, (10, 0.0), 0),
1050+
(g.binomialvariate, (10, 1.0), 10),
10481051
(g.paretovariate, (float('inf'),), 1.0),
10491052
(g.weibullvariate, (10.0, float('inf')), 10.0),
10501053
(g.weibullvariate, (0.0, 10.0), 0.0),
10511054
]:
10521055
for i in range(N):
10531056
self.assertEqual(variate(*args), expected)
10541057

1058+
def test_binomialvariate(self):
1059+
B = random.binomialvariate
1060+
1061+
# Cover all the code paths
1062+
with self.assertRaises(ValueError):
1063+
B(n=-1) # Negative n
1064+
with self.assertRaises(ValueError):
1065+
B(n=1, p=-0.5) # Negative p
1066+
with self.assertRaises(ValueError):
1067+
B(n=1, p=1.5) # p > 1.0
1068+
self.assertEqual(B(10, 0.0), 0) # p == 0.0
1069+
self.assertEqual(B(10, 1.0), 10) # p == 1.0
1070+
self.assertTrue(B(1, 0.3) in {0, 1}) # n == 1 fast path
1071+
self.assertTrue(B(1, 0.9) in {0, 1}) # n == 1 fast path
1072+
self.assertTrue(B(1, 0.0) in {0}) # n == 1 fast path
1073+
self.assertTrue(B(1, 1.0) in {1}) # n == 1 fast path
1074+
1075+
# BG method p <= 0.5 and n*p=1.25
1076+
self.assertTrue(B(5, 0.25) in set(range(6)))
1077+
1078+
# BG method p >= 0.5 and n*(1-p)=1.25
1079+
self.assertTrue(B(5, 0.75) in set(range(6)))
1080+
1081+
# BTRS method p <= 0.5 and n*p=25
1082+
self.assertTrue(B(100, 0.25) in set(range(101)))
1083+
1084+
# BTRS method p > 0.5 and n*(1-p)=25
1085+
self.assertTrue(B(100, 0.75) in set(range(101)))
1086+
1087+
# Statistical tests chosen such that they are
1088+
# exceedingly unlikely to ever fail for correct code.
1089+
1090+
# BG code path
1091+
# Expected dist: [31641, 42188, 21094, 4688, 391]
1092+
c = Counter(B(4, 0.25) for i in range(100_000))
1093+
self.assertTrue(29_641 <= c[0] <= 33_641, c)
1094+
self.assertTrue(40_188 <= c[1] <= 44_188)
1095+
self.assertTrue(19_094 <= c[2] <= 23_094)
1096+
self.assertTrue(2_688 <= c[3] <= 6_688)
1097+
self.assertEqual(set(c), {0, 1, 2, 3, 4})
1098+
1099+
# BTRS code path
1100+
# Sum of c[20], c[21], c[22], c[23], c[24] expected to be 36,214
1101+
c = Counter(B(100, 0.25) for i in range(100_000))
1102+
self.assertTrue(34_214 <= c[20]+c[21]+c[22]+c[23]+c[24] <= 38_214)
1103+
self.assertTrue(set(c) <= set(range(101)))
1104+
self.assertEqual(c.total(), 100_000)
1105+
1106+
# Demonstrate the BTRS works for huge values of n
1107+
self.assertTrue(19_000_000 <= B(100_000_000, 0.2) <= 21_000_000)
1108+
self.assertTrue(89_000_000 <= B(100_000_000, 0.9) <= 91_000_000)
1109+
1110+
10551111
def test_von_mises_range(self):
10561112
# Issue 17149: von mises variates were not consistently in the
10571113
# range [0, 2*PI].
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add random.binomialvariate().

0 commit comments

Comments
 (0)