Skip to content

Commit 1c60897

Browse files
authored
Merge pull request #274 from python-adaptive/feature/min_npoints
AverageLearner: implement min_npoints
2 parents 9c6668b + 52a7252 commit 1c60897

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

adaptive/learner/average_learner.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner):
1919
Desired absolute tolerance.
2020
rtol : float
2121
Desired relative tolerance.
22+
min_npoints : int
23+
Minimum number of points to sample.
2224
2325
Attributes
2426
----------
@@ -30,7 +32,7 @@ class AverageLearner(BaseLearner):
3032
Number of evaluated points.
3133
"""
3234

33-
def __init__(self, function, atol=None, rtol=None):
35+
def __init__(self, function, atol=None, rtol=None, min_npoints=2):
3436
if atol is None and rtol is None:
3537
raise Exception("At least one of `atol` and `rtol` should be set.")
3638
if atol is None:
@@ -44,6 +46,8 @@ def __init__(self, function, atol=None, rtol=None):
4446
self.atol = atol
4547
self.rtol = rtol
4648
self.npoints = 0
49+
# Cannot estimate standard deviation with fewer than 2 points.
50+
self.min_npoints = max(min_npoints, 2)
4751
self.sum_f = 0
4852
self.sum_f_sq = 0
4953

@@ -92,7 +96,7 @@ def std(self):
9296
"""The corrected sample standard deviation of the values
9397
in `data`."""
9498
n = self.npoints
95-
if n < 2:
99+
if n < self.min_npoints:
96100
return np.inf
97101
numerator = self.sum_f_sq - n * self.mean ** 2
98102
if numerator < 0:
@@ -106,7 +110,7 @@ def loss(self, real=True, *, n=None):
106110
n = self.npoints if real else self.n_requested
107111
else:
108112
n = n
109-
if n < 2:
113+
if n < self.min_npoints:
110114
return np.inf
111115
standard_error = self.std / sqrt(n)
112116
return max(
@@ -150,10 +154,11 @@ def __getstate__(self):
150154
self.function,
151155
self.atol,
152156
self.rtol,
157+
self.min_npoints,
153158
self._get_data(),
154159
)
155160

156161
def __setstate__(self, state):
157-
function, atol, rtol, data = state
158-
self.__init__(function, atol, rtol)
162+
function, atol, rtol, min_npoints, data = state
163+
self.__init__(function, atol, rtol, min_npoints)
159164
self._set_data(data)

adaptive/tests/test_average_learner.py

+13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from adaptive.learner import AverageLearner
7+
from adaptive.runner import simple
78

89

910
def test_only_returns_new_points():
@@ -46,3 +47,15 @@ def test_avg_std_and_npoints():
4647
assert learner.npoints == len(learner.data)
4748
assert abs(learner.sum_f - values.sum()) < 1e-13
4849
assert abs(learner.std - std) < 1e-13
50+
51+
52+
def test_min_npoints():
53+
def constant_function(seed):
54+
return 0.1
55+
56+
for min_npoints in [1, 2, 3]:
57+
learner = AverageLearner(
58+
constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints
59+
)
60+
simple(learner, lambda l: l.loss() < 1)
61+
assert learner.npoints >= max(2, min_npoints)

0 commit comments

Comments
 (0)