Skip to content

Commit e94dbb7

Browse files
committed
use the DataPoint class in the AverageLearner
1 parent 1c452b9 commit e94dbb7

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

adaptive/learner/average_learner.py

+12-24
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66

7+
from adaptive.learner.average_mixin import DataPoint
78
from adaptive.learner.base_learner import BaseLearner
89
from adaptive.notebook_integration import ensure_holoviews
910
from adaptive.utils import cache_latest
@@ -40,14 +41,15 @@ def __init__(self, function, atol=None, rtol=None):
4041
if rtol is None:
4142
rtol = np.inf
4243

43-
self.data = {}
44+
self.data = DataPoint()
4445
self.pending_points = set()
4546
self.function = function
4647
self.atol = atol
4748
self.rtol = rtol
48-
self.npoints = 0
49-
self.sum_f = 0
50-
self.sum_f_sq = 0
49+
50+
@property
51+
def npoints(self):
52+
return self.data.n
5153

5254
@property
5355
def n_requested(self):
@@ -72,35 +74,23 @@ def ask(self, n, tell_pending=True):
7274

7375
def tell(self, n, value):
7476
if n in self.data:
75-
# The point has already been added before.
7677
return
77-
7878
self.data[n] = value
7979
self.pending_points.discard(n)
80-
self.sum_f += value
81-
self.sum_f_sq += value ** 2
82-
self.npoints += 1
8380

8481
def tell_pending(self, n):
8582
self.pending_points.add(n)
8683

8784
@property
8885
def mean(self):
8986
"""The average of all values in `data`."""
90-
return self.sum_f / self.npoints
87+
return self.data.mean
9188

9289
@property
9390
def std(self):
9491
"""The corrected sample standard deviation of the values
9592
in `data`."""
96-
n = self.npoints
97-
if n < 2:
98-
return np.inf
99-
numerator = self.sum_f_sq - n * self.mean ** 2
100-
if numerator < 0:
101-
# in this case the numerator ~ -1e-15
102-
return 0
103-
return sqrt(numerator / (n - 1))
93+
return self.data.std
10494

10595
@cache_latest
10696
def loss(self, real=True, *, n=None):
@@ -110,10 +100,8 @@ def loss(self, real=True, *, n=None):
110100
n = n
111101
if n < 2:
112102
return np.inf
113-
standard_error = self.std / sqrt(n)
114-
return max(
115-
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
116-
)
103+
sem = self.data.standard_error
104+
return max(sem / self.atol, sem / abs(self.mean) / self.rtol)
117105

118106
def _loss_improvement(self, n):
119107
loss = self.loss()
@@ -142,7 +130,7 @@ def plot(self):
142130
return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)
143131

144132
def _get_data(self):
145-
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
133+
return dict(self.data)
146134

147135
def _set_data(self, data):
148-
self.data, self.npoints, self.sum_f, self.sum_f_sq = data
136+
self.data = DataPoint(data)

0 commit comments

Comments
 (0)