Skip to content

Commit 9badc86

Browse files
authored
Compute from_sample() in a single pass over the data (#92284)
1 parent 6dcfd6c commit 9badc86

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

Lib/statistics.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,17 @@ def _sum(data):
206206

207207

208208
def _ss(data, c=None):
209-
"""Return sum of square deviations of sequence data.
209+
"""Return the exact mean and sum of square deviations of sequence data.
210+
211+
Calculations are done in a single pass, allowing the input to be an iterator.
212+
213+
If given *c* is used the mean; otherwise, it is calculated from the data.
214+
Use the *c* argument with care, as it can lead to garbage results.
210215
211-
If ``c`` is None, the mean is calculated in one pass, and the deviations
212-
from the mean are calculated in a second pass. Otherwise, deviations are
213-
calculated from ``c`` as given. Use the second case with care, as it can
214-
lead to garbage results.
215216
"""
216217
if c is not None:
217-
T, total, count = _sum((d := x - c) * d for x in data)
218-
return (T, total, count)
218+
T, ssd, count = _sum((d := x - c) * d for x in data)
219+
return (T, ssd, c, count)
219220
count = 0
220221
types = set()
221222
types_add = types.add
@@ -228,20 +229,21 @@ def _ss(data, c=None):
228229
sx_partials[d] += n
229230
sxx_partials[d] += n * n
230231
if not count:
231-
total = Fraction(0)
232+
ssd = c = Fraction(0)
232233
elif None in sx_partials:
233234
# The sum will be a NAN or INF. We can ignore all the finite
234235
# partials, and just look at this special one.
235-
total = sx_partials[None]
236+
ssd = c = sx_partials[None]
236237
assert not _isfinite(total)
237238
else:
238239
sx = sum(Fraction(n, d) for d, n in sx_partials.items())
239240
sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
240241
# This formula has poor numeric properties for floats,
241242
# but with fractions it is exact.
242-
total = (count * sxx - sx * sx) / count
243+
ssd = (count * sxx - sx * sx) / count
244+
c = sx / count
243245
T = reduce(_coerce, types, int) # or raise TypeError
244-
return (T, total, count)
246+
return (T, ssd, c, count)
245247

246248

247249
def _isfinite(x):
@@ -854,7 +856,7 @@ def variance(data, xbar=None):
854856
Fraction(67, 108)
855857
856858
"""
857-
T, ss, n = _ss(data, xbar)
859+
T, ss, c, n = _ss(data, xbar)
858860
if n < 2:
859861
raise StatisticsError('variance requires at least two data points')
860862
return _convert(ss / (n - 1), T)
@@ -895,7 +897,7 @@ def pvariance(data, mu=None):
895897
Fraction(13, 72)
896898
897899
"""
898-
T, ss, n = _ss(data, mu)
900+
T, ss, c, n = _ss(data, mu)
899901
if n < 1:
900902
raise StatisticsError('pvariance requires at least one data point')
901903
return _convert(ss / n, T)
@@ -910,7 +912,7 @@ def stdev(data, xbar=None):
910912
1.0810874155219827
911913
912914
"""
913-
T, ss, n = _ss(data, xbar)
915+
T, ss, c, n = _ss(data, xbar)
914916
if n < 2:
915917
raise StatisticsError('stdev requires at least two data points')
916918
mss = ss / (n - 1)
@@ -928,7 +930,7 @@ def pstdev(data, mu=None):
928930
0.986893273527251
929931
930932
"""
931-
T, ss, n = _ss(data, mu)
933+
T, ss, c, n = _ss(data, mu)
932934
if n < 1:
933935
raise StatisticsError('pstdev requires at least one data point')
934936
mss = ss / n
@@ -937,6 +939,15 @@ def pstdev(data, mu=None):
937939
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
938940

939941

942+
def _mean_stdev(data):
943+
"""In one pass, compute the mean and sample standard deviation as floats."""
944+
T, ss, xbar, n = _ss(data)
945+
if n < 2:
946+
raise StatisticsError('stdev requires at least two data points')
947+
mss = ss / (n - 1)
948+
return float(xbar), _float_sqrt_of_frac(mss.numerator, mss.denominator)
949+
950+
940951
# === Statistics for relations between two inputs ===
941952

942953
# See https://en.wikipedia.org/wiki/Covariance
@@ -1171,9 +1182,7 @@ def __init__(self, mu=0.0, sigma=1.0):
11711182
@classmethod
11721183
def from_samples(cls, data):
11731184
"Make a normal distribution instance from sample data."
1174-
if not isinstance(data, (list, tuple)):
1175-
data = list(data)
1176-
return cls(mean(data), stdev(data))
1185+
return cls(*_mean_stdev(data))
11771186

11781187
def samples(self, n, *, seed=None):
11791188
"Generate *n* samples for a given mean and standard deviation."

0 commit comments

Comments
 (0)