Skip to content

Commit 79c9edf

Browse files
committed
add a DataPoint class which is a dict-like datastructure
that keeps track of the length, sum, and sum of squares of the values. It has properties to calculate the mean, sample standard deviation, and standard error.
1 parent efe706d commit 79c9edf

File tree

1 file changed

+65
-18
lines changed

1 file changed

+65
-18
lines changed

adaptive/learner/average_mixin.py

+65-18
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,15 @@
1414
class AverageMixin:
1515
@property
1616
def data(self):
17-
return {k: sum(v.values()) / len(v) for k, v in self._data.items()}
17+
return {k: v.mean for k, v in self._data.items()}
1818

1919
@property
2020
def data_sem(self):
21-
return {k: self.standard_error(v.values())
21+
return {k: v.standard_error if v.n >= self.min_values_per_point else inf
2222
for k, v in self._data.items()}
2323

24-
def standard_error(self, lst):
25-
n = len(lst)
26-
if n < self.min_values_per_point:
27-
return inf
28-
sum_f_sq = sum(x**2 for x in lst)
29-
mean = sum(x for x in lst) / n
30-
numerator = sum_f_sq - n * mean**2
31-
if numerator < 0:
32-
# This means that the numerator is ~ -1e-15
33-
return 0
34-
std = sqrt(numerator / (n - 1))
35-
return std / sqrt(n)
36-
3724
def mean_values_per_point(self):
38-
return np.mean([len(x.values()) for x in self._data.values()])
25+
return np.mean([x.n for x in self._data.values()])
3926

4027
def get_seed(self, point):
4128
_data = self._data.get(point, {})
@@ -77,7 +64,7 @@ def _remove_from_to_pending(self, point):
7764
def _add_to_data(self, point, value):
7865
x, seed = self.unpack_point(point)
7966
if x not in self._data:
80-
self._data[x] = {}
67+
self._data[x] = DataPoint()
8168
self._data[x][seed] = value
8269

8370
def ask(self, n, tell_pending=True):
@@ -142,7 +129,7 @@ def needs_more_data(p):
142129

143130

144131
def add_average_mixin(cls):
145-
names = ('data', 'data_sem', 'standard_error', 'mean_values_per_point',
132+
names = ('data', 'data_sem', 'mean_values_per_point',
146133
'get_seed', 'loss_per_existing_point', '_add_to_pending',
147134
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
148135
'_normalize_new_points_loss_improvements',
@@ -153,3 +140,63 @@ def add_average_mixin(cls):
153140
setattr(cls, name, getattr(AverageMixin, name))
154141

155142
return cls
143+
144+
145+
class DataPoint(dict):
146+
"""A dict-like data structure that keeps track of the
147+
length, sum, and sum of squares of the values.
148+
149+
It has properties to calculate the mean, sample
150+
standard deviation, and standard error."""
151+
def __init__(self, *args, **kwargs):
152+
self.update(*args, **kwargs)
153+
self.sum = 0
154+
self.sum_sq = 0
155+
self.n = 0
156+
157+
def __setitem__(self, key, val):
158+
self._remove(key)
159+
self.sum += val
160+
self.sum_sq += val**2
161+
self.n += 1
162+
super().__setitem__(key, val)
163+
164+
def _remove(self, key):
165+
if key in self:
166+
val = self[key]
167+
self.sum -= val
168+
self.sum_sq -= val**2
169+
self.n -= 1
170+
171+
@property
172+
def mean(self):
173+
return self.sum / self.n
174+
175+
@property
176+
def std(self):
177+
numerator = self.sum_sq - self.n * self.mean**2
178+
if numerator < 0:
179+
# This means that the numerator is ~ -1e-15
180+
return 0
181+
return sqrt(numerator / (self.n - 1))
182+
183+
@property
184+
def standard_error(self):
185+
return self.std / sqrt(self.n)
186+
187+
def __delitem__(self, key):
188+
self._remove(key)
189+
super().__delitem__(key)
190+
191+
def pop(self, *args):
192+
self._remove(args[0])
193+
return super().pop(*args)
194+
195+
def check(self):
196+
import numpy
197+
import scipy.stats
198+
vals = list(self.values())
199+
numpy.testing.assert_almost_equal(numpy.mean(vals), self.mean)
200+
numpy.testing.assert_almost_equal(numpy.std(vals, ddof=1), self.std)
201+
numpy.testing.assert_almost_equal(self.standard_error, scipy.stats.sem(vals))
202+
assert self.n == len(vals)

0 commit comments

Comments
 (0)