Skip to content

Commit d22994a

Browse files
committed
add a DataPoint class which is a dict-like datastructure
that keeps track of the length, sum, and sum of squares of the values. It has properties to calculate the mean, sample standard deviation, and standard error.
1 parent efe706d commit d22994a

File tree

1 file changed

+70
-18
lines changed

1 file changed

+70
-18
lines changed

adaptive/learner/average_mixin.py

+70-18
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# -*- coding: utf-8 -*-
22

3+
from collections import Mapping
34
from math import sqrt
45
import sys
56

67
import numpy as np
8+
import scipy.stats
79

810
from .learner1D import Learner1D
911

@@ -14,28 +16,15 @@
1416
class AverageMixin:
1517
@property
1618
def data(self):
17-
return {k: sum(v.values()) / len(v) for k, v in self._data.items()}
19+
return {k: v.mean for k, v in self._data.items()}
1820

1921
@property
2022
def data_sem(self):
21-
return {k: self.standard_error(v.values())
23+
return {k: v.standard_error if v.n >= self.min_values_per_point else inf
2224
for k, v in self._data.items()}
2325

24-
def standard_error(self, lst):
25-
n = len(lst)
26-
if n < self.min_values_per_point:
27-
return inf
28-
sum_f_sq = sum(x**2 for x in lst)
29-
mean = sum(x for x in lst) / n
30-
numerator = sum_f_sq - n * mean**2
31-
if numerator < 0:
32-
# This means that the numerator is ~ -1e-15
33-
return 0
34-
std = sqrt(numerator / (n - 1))
35-
return std / sqrt(n)
36-
3726
def mean_values_per_point(self):
38-
return np.mean([len(x.values()) for x in self._data.values()])
27+
return np.mean([x.n for x in self._data.values()])
3928

4029
def get_seed(self, point):
4130
_data = self._data.get(point, {})
@@ -77,7 +66,7 @@ def _remove_from_to_pending(self, point):
7766
def _add_to_data(self, point, value):
7867
x, seed = self.unpack_point(point)
7968
if x not in self._data:
80-
self._data[x] = {}
69+
self._data[x] = DataPoint()
8170
self._data[x][seed] = value
8271

8372
def ask(self, n, tell_pending=True):
@@ -142,7 +131,7 @@ def needs_more_data(p):
142131

143132

144133
def add_average_mixin(cls):
145-
names = ('data', 'data_sem', 'standard_error', 'mean_values_per_point',
134+
names = ('data', 'data_sem', 'mean_values_per_point',
146135
'get_seed', 'loss_per_existing_point', '_add_to_pending',
147136
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
148137
'_normalize_new_points_loss_improvements',
@@ -153,3 +142,66 @@ def add_average_mixin(cls):
153142
setattr(cls, name, getattr(AverageMixin, name))
154143

155144
return cls
145+
146+
147+
class DataPoint(dict):
148+
"""A dict-like data structure that keeps track of the
149+
length, sum, and sum of squares of the values.
150+
151+
It has properties to calculate the mean, sample
152+
standard deviation, and standard error."""
153+
def __init__(self, *args, **kwargs):
154+
self.sum = 0
155+
self.sum_sq = 0
156+
self.n = 0
157+
self.update(*args, **kwargs)
158+
159+
def __setitem__(self, key, val):
160+
self._remove(key)
161+
self.sum += val
162+
self.sum_sq += val**2
163+
self.n += 1
164+
super().__setitem__(key, val)
165+
166+
def _remove(self, key):
167+
if key in self:
168+
val = self[key]
169+
self.sum -= val
170+
self.sum_sq -= val**2
171+
self.n -= 1
172+
173+
@property
174+
def mean(self):
175+
return self.sum / self.n
176+
177+
@property
178+
def std(self):
179+
numerator = self.sum_sq - self.n * self.mean**2
180+
if numerator < 0:
181+
# This means that the numerator is ~ -1e-15
182+
return 0
183+
return sqrt(numerator / (self.n - 1))
184+
185+
@property
186+
def standard_error(self):
187+
return self.std / sqrt(self.n)
188+
189+
def __delitem__(self, key):
190+
self._remove(key)
191+
super().__delitem__(key)
192+
193+
def pop(self, *args):
194+
self._remove(args[0])
195+
return super().pop(*args)
196+
197+
def update(self, other=None, **kwargs):
198+
iterator = other if isinstance(other, Mapping) else kwargs
199+
for k, v in iterator.items():
200+
self[k] = v
201+
202+
def assert_consistent_data_structure(self):
203+
vals = list(self.values())
204+
np.testing.assert_almost_equal(np.mean(vals), self.mean)
205+
np.testing.assert_almost_equal(np.std(vals, ddof=1), self.std)
206+
np.testing.assert_almost_equal(self.standard_error, scipy.stats.sem(vals))
207+
assert self.n == len(vals)

0 commit comments

Comments
 (0)