@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner):
19
19
Desired absolute tolerance.
20
20
rtol : float
21
21
Desired relative tolerance.
22
+ min_npoints : int
23
+ Minimum number of points to sample.
22
24
23
25
Attributes
24
26
----------
@@ -30,7 +32,7 @@ class AverageLearner(BaseLearner):
30
32
Number of evaluated points.
31
33
"""
32
34
33
- def __init__ (self , function , atol = None , rtol = None ):
35
+ def __init__ (self , function , atol = None , rtol = None , min_npoints = 2 ):
34
36
if atol is None and rtol is None :
35
37
raise Exception ("At least one of `atol` and `rtol` should be set." )
36
38
if atol is None :
@@ -44,6 +46,8 @@ def __init__(self, function, atol=None, rtol=None):
44
46
self .atol = atol
45
47
self .rtol = rtol
46
48
self .npoints = 0
49
+ # Cannot estimate standard deviation with fewer than 2 points.
50
+ self .min_npoints = max (min_npoints , 2 )
47
51
self .sum_f = 0
48
52
self .sum_f_sq = 0
49
53
@@ -92,7 +96,7 @@ def std(self):
92
96
"""The corrected sample standard deviation of the values
93
97
in `data`."""
94
98
n = self .npoints
95
- if n < 2 :
99
+ if n < self . min_npoints :
96
100
return np .inf
97
101
numerator = self .sum_f_sq - n * self .mean ** 2
98
102
if numerator < 0 :
@@ -106,7 +110,7 @@ def loss(self, real=True, *, n=None):
106
110
n = self .npoints if real else self .n_requested
107
111
else :
108
112
n = n
109
- if n < 2 :
113
+ if n < self . min_npoints :
110
114
return np .inf
111
115
standard_error = self .std / sqrt (n )
112
116
return max (
@@ -150,10 +154,11 @@ def __getstate__(self):
150
154
self .function ,
151
155
self .atol ,
152
156
self .rtol ,
157
+ self .min_npoints ,
153
158
self ._get_data (),
154
159
)
155
160
156
161
def __setstate__ (self , state ):
157
- function , atol , rtol , data = state
158
- self .__init__ (function , atol , rtol )
162
+ function , atol , rtol , min_npoints , data = state
163
+ self .__init__ (function , atol , rtol , min_npoints )
159
164
self ._set_data (data )
0 commit comments