14
14
class AverageMixin :
15
15
@property
def data(self):
    """Return a dict mapping every key of ``self._data`` to the ``mean``
    attribute of the entry stored under that key."""
    averages = {}
    for point, samples in self._data.items():
        averages[point] = samples.mean
    return averages
18
18
19
19
@property
def data_sem(self):
    """Return a dict mapping every key of ``self._data`` to the
    ``standard_error`` of its entry.

    Entries whose ``n`` is below ``self.min_values_per_point`` are
    reported as ``inf`` instead; their ``standard_error`` is not
    evaluated at all.
    """
    errors = {}
    for point, samples in self._data.items():
        if samples.n >= self.min_values_per_point:
            errors[point] = samples.standard_error
        else:
            errors[point] = inf
    return errors
23
23
24
- def standard_error (self , lst ):
25
- n = len (lst )
26
- if n < self .min_values_per_point :
27
- return inf
28
- sum_f_sq = sum (x ** 2 for x in lst )
29
- mean = sum (x for x in lst ) / n
30
- numerator = sum_f_sq - n * mean ** 2
31
- if numerator < 0 :
32
- # This means that the numerator is ~ -1e-15
33
- return 0
34
- std = sqrt (numerator / (n - 1 ))
35
- return std / sqrt (n )
36
-
37
24
def mean_values_per_point(self):
    """Return ``np.mean`` of the ``n`` attributes of all entries in
    ``self._data``."""
    counts = []
    for samples in self._data.values():
        counts.append(samples.n)
    return np.mean(counts)
39
26
40
27
def get_seed (self , point ):
41
28
_data = self ._data .get (point , {})
@@ -77,7 +64,7 @@ def _remove_from_to_pending(self, point):
77
64
def _add_to_data (self , point , value ):
78
65
x , seed = self .unpack_point (point )
79
66
if x not in self ._data :
80
- self ._data [x ] = {}
67
+ self ._data [x ] = DataPoint ()
81
68
self ._data [x ][seed ] = value
82
69
83
70
def ask (self , n , tell_pending = True ):
@@ -142,7 +129,7 @@ def needs_more_data(p):
142
129
143
130
144
131
def add_average_mixin (cls ):
145
- names = ('data' , 'data_sem' , 'standard_error' , ' mean_values_per_point' ,
132
+ names = ('data' , 'data_sem' , 'mean_values_per_point' ,
146
133
'get_seed' , 'loss_per_existing_point' , '_add_to_pending' ,
147
134
'_remove_from_to_pending' , '_add_to_data' , 'ask' , 'n_values' ,
148
135
'_normalize_new_points_loss_improvements' ,
@@ -153,3 +140,63 @@ def add_average_mixin(cls):
153
140
setattr (cls , name , getattr (AverageMixin , name ))
154
141
155
142
return cls
143
+
144
+
145
class DataPoint(dict):
    """A dict that keeps running statistics of its (numeric) values.

    The number of items (``n``), their sum (``sum``) and the sum of their
    squares (``sum_sq``) are updated on every mutation, so the mean, the
    sample standard deviation and the standard error can be read without
    iterating over the values.

    Every mutating ``dict`` method is overridden, because the built-in
    implementations (``dict.update``, ``dict.setdefault``, ...) write to the
    underlying table directly and never call ``__setitem__``.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict``; initial items are counted."""
        super().__init__()
        # The accumulators must exist *before* any item is inserted, and the
        # initial items must go through ``__setitem__`` to be counted.
        self.sum = 0
        self.sum_sq = 0
        self.n = 0
        self.update(*args, **kwargs)

    def __reduce__(self):
        # Rebuild through the constructor.  The default protocol re-inserts
        # the items via ``__setitem__`` *and* restores the instance state:
        # ``copy.deepcopy`` would then count every value twice, and
        # unpickling would touch the accumulators before they exist.
        return (type(self), (dict(self),))

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``, replacing (and un-counting) any
        previous value for that key."""
        # Square first so a non-numeric value fails before any accumulator
        # has been modified.
        val_sq = val ** 2
        self._remove(key)
        self.sum += val
        self.sum_sq += val_sq
        self.n += 1
        super().__setitem__(key, val)

    def __delitem__(self, key):
        """Delete ``key``; raises ``KeyError`` if it is absent."""
        self._remove(key)
        super().__delitem__(key)

    def __ior__(self, other):
        """In-place merge (``d |= other``) that keeps the statistics right."""
        self.update(other)
        return self

    def _subtract(self, val):
        """Take one value out of the running accumulators."""
        self.sum -= val
        self.sum_sq -= val ** 2
        self.n -= 1

    def _remove(self, key):
        """Un-count the value stored under ``key`` (no-op if absent).

        Only the accumulators change; the item itself stays in the dict.
        """
        if key in self:
            self._subtract(self[key])

    def update(self, *args, **kwargs):
        """Like ``dict.update``, but routed through ``__setitem__``."""
        for key, val in dict(*args, **kwargs).items():
            self[key] = val

    def setdefault(self, key, default=None):
        """Like ``dict.setdefault``; a newly inserted default is counted,
        so it has to be numeric."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        """Remove ``key`` and return its value (or ``default`` if given and
        the key is absent; otherwise ``KeyError`` is raised)."""
        self._remove(key)
        return super().pop(key, *default)

    def popitem(self):
        """Remove and return a ``(key, value)`` pair, un-counting the value."""
        key, val = super().popitem()
        self._subtract(val)
        return key, val

    def clear(self):
        """Remove all items and reset the accumulators."""
        super().clear()
        self.sum = 0
        self.sum_sq = 0
        self.n = 0

    @property
    def mean(self):
        """Arithmetic mean of the values.

        Raises ``ZeroDivisionError`` when the dict is empty.
        """
        return self.sum / self.n

    @property
    def std(self):
        """Sample standard deviation (``n - 1`` in the denominator).

        Returns ``nan`` for fewer than two values, where it is undefined.
        """
        if self.n < 2:
            return float("nan")
        numerator = self.sum_sq - self.n * self.mean ** 2
        if numerator < 0:
            # This means that the numerator is ~ -1e-15
            return 0
        return sqrt(numerator / (self.n - 1))

    @property
    def standard_error(self):
        """Standard error of the mean, ``std / sqrt(n)``.

        Returns ``nan`` for fewer than two values.
        """
        if self.n < 2:
            return float("nan")
        return self.std / sqrt(self.n)

    def check(self):
        """Debugging aid: assert that the running statistics agree with
        values recomputed from scratch by numpy and scipy."""
        import numpy
        import scipy.stats
        vals = list(self.values())
        numpy.testing.assert_almost_equal(numpy.mean(vals), self.mean)
        numpy.testing.assert_almost_equal(numpy.std(vals, ddof=1), self.std)
        numpy.testing.assert_almost_equal(self.standard_error, scipy.stats.sem(vals))
        assert self.n == len(vals)
0 commit comments