1
1
# -*- coding: utf-8 -*-
2
2
3
import sys
from math import sqrt

try:
    # The ABC aliases in ``collections`` were removed in Python 3.10.
    from collections.abc import Mapping
except ImportError:  # interpreters that predate ``collections.abc``
    from collections import Mapping

import numpy as np
import scipy.stats

from .learner1D import Learner1D
9
11
14
16
class AverageMixin :
15
17
@property
def data(self):
    """Return the mean of the stored values for every point.

    Returns
    -------
    dict
        Maps each key of ``self._data`` to the ``mean`` attribute of the
        container stored under that key.
    """
    return {x: values.mean for x, values in self._data.items()}
18
20
19
21
@property
def data_sem(self):
    """Return the standard error of the mean for every point.

    Returns
    -------
    dict
        Maps each key of ``self._data`` to the ``standard_error`` attribute
        of the container stored under that key, or to ``inf`` when the
        container holds too few values.
    """
    # A standard error needs at least two values (it divides by n - 1), so
    # a single-value point is treated as under-sampled even when
    # ``min_values_per_point`` would allow it.
    return {
        x: values.standard_error
        if values.n >= max(self.min_values_per_point, 2) else inf
        for x, values in self._data.items()
    }
23
25
24
def standard_error(self, lst):
    """Return the standard error of the mean of the values in ``lst``.

    Parameters
    ----------
    lst : iterable of numbers
        The sampled values.

    Returns
    -------
    float
        ``std / sqrt(n)`` with the sample (``n - 1``) standard deviation,
        or ``inf`` when there are fewer than
        ``max(self.min_values_per_point, 2)`` values.
    """
    values = list(lst)
    n = len(values)
    # Fewer than two values would divide by ``n - 1 == 0`` below.
    if n < max(self.min_values_per_point, 2):
        return inf
    mean = sum(values) / n
    # Summing squared deviations cannot go negative, unlike
    # ``sum(x**2) - n * mean**2`` which suffers from cancellation.
    squared_deviations = sum((x - mean) ** 2 for x in values)
    std = sqrt(squared_deviations / (n - 1))
    return std / sqrt(n)
36
-
37
26
def mean_values_per_point(self):
    """Return the average of the ``n`` attribute over all stored points.

    The average is taken with ``np.mean`` over every container in
    ``self._data``.
    """
    return np.mean([values.n for values in self._data.values()])
39
28
40
29
def get_seed (self , point ):
41
30
_data = self ._data .get (point , {})
@@ -77,7 +66,7 @@ def _remove_from_to_pending(self, point):
77
66
def _add_to_data (self , point , value ):
78
67
x , seed = self .unpack_point (point )
79
68
if x not in self ._data :
80
- self ._data [x ] = {}
69
+ self ._data [x ] = DataPoint ()
81
70
self ._data [x ][seed ] = value
82
71
83
72
def ask (self , n , tell_pending = True ):
@@ -142,7 +131,7 @@ def needs_more_data(p):
142
131
143
132
144
133
def add_average_mixin (cls ):
145
- names = ('data' , 'data_sem' , 'standard_error' , ' mean_values_per_point' ,
134
+ names = ('data' , 'data_sem' , 'mean_values_per_point' ,
146
135
'get_seed' , 'loss_per_existing_point' , '_add_to_pending' ,
147
136
'_remove_from_to_pending' , '_add_to_data' , 'ask' , 'n_values' ,
148
137
'_normalize_new_points_loss_improvements' ,
@@ -153,3 +142,66 @@ def add_average_mixin(cls):
153
142
setattr (cls , name , getattr (AverageMixin , name ))
154
143
155
144
return cls
145
+
146
+
147
class DataPoint(dict):
    """A dict that keeps running statistics of its values.

    The count (``n``), sum (``sum``) and sum of squares (``sum_sq``) of the
    stored values are updated on every insertion and removal, so the mean,
    sample standard deviation and standard error are available without
    iterating over the values.

    Values must be numeric: they have to support ``+``, ``-`` and ``** 2``.
    """

    def __init__(self, *args, **kwargs):
        """Create the container; all arguments are forwarded to `update`."""
        super().__init__()
        self.sum = 0
        self.sum_sq = 0
        self.n = 0
        self.update(*args, **kwargs)

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``, replacing any previous value."""
        # Square first: a non-numeric value then raises before the running
        # totals have been touched.
        val_sq = val ** 2
        self._remove(key)
        self.sum += val
        self.sum_sq += val_sq
        self.n += 1
        super().__setitem__(key, val)

    def __delitem__(self, key):
        """Delete ``key``; raises ``KeyError`` if it is missing."""
        self._remove(key)
        super().__delitem__(key)

    def __ior__(self, other):
        """Support ``dp |= other`` through `update` so totals stay in sync."""
        self.update(other)
        return self

    def __reduce__(self):
        """Rebuild copies and unpickled objects through the constructor.

        The default protocol for dict subclasses re-inserts the items via
        ``__setitem__`` in addition to restoring the attributes, which
        counts every value twice when copying and fails when unpickling.
        """
        return type(self), (dict(self),)

    def _forget(self, val):
        """Subtract a value that is leaving the dict from the totals."""
        self.sum -= val
        self.sum_sq -= val ** 2
        self.n -= 1
        if self.n == 0:
            # Drop floating-point residue once nothing is left.
            self.sum = 0
            self.sum_sq = 0

    def _remove(self, key):
        """Discount the value stored under ``key``, if there is one.

        Only the totals are adjusted; the item itself stays in the dict.
        """
        if key in self:
            self._forget(self[key])

    @property
    def mean(self):
        """Arithmetic mean of the values; NaN when the dict is empty."""
        if self.n == 0:
            return float("nan")
        return self.sum / self.n

    @property
    def std(self):
        """Sample standard deviation (``n - 1`` in the denominator).

        Returns NaN for fewer than two values, where it is undefined.
        """
        if self.n < 2:
            return float("nan")
        numerator = self.sum_sq - self.n * self.mean ** 2
        if numerator < 0:
            # Rounding can leave the numerator at ~ -1e-15 for equal values.
            return 0
        return sqrt(numerator / (self.n - 1))

    @property
    def standard_error(self):
        """Standard error of the mean, ``std / sqrt(n)``.

        Returns NaN for fewer than two values.
        """
        if self.n < 2:
            return float("nan")
        return self.std / sqrt(self.n)

    def pop(self, key, *default):
        """Remove ``key`` and return its value, like `dict.pop`."""
        self._remove(key)
        return super().pop(key, *default)

    def popitem(self):
        """Remove and return a ``(key, value)`` pair, like `dict.popitem`."""
        key, val = super().popitem()
        self._forget(val)
        return key, val

    def clear(self):
        """Remove every item and reset the running totals."""
        super().clear()
        self.sum = 0
        self.sum_sq = 0
        self.n = 0

    def setdefault(self, key, default=None):
        """Return ``self[key]``, inserting ``default`` first if it is missing.

        The default goes through ``__setitem__`` and so must be numeric.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def update(self, other=None, **kwargs):
        """Insert items like `dict.update`, keeping the totals in sync.

        Parameters
        ----------
        other : mapping or iterable of (key, value) pairs, optional
            Items to insert.
        **kwargs
            Further items; they take precedence over ``other``.
        """
        pending = {} if other is None else dict(other)
        pending.update(kwargs)
        for key, val in pending.items():
            self[key] = val

    def assert_consistent_data_structure(self):
        """Check the running statistics against a fresh computation.

        Raises
        ------
        AssertionError
            If ``n``, ``mean``, ``std`` or ``standard_error`` disagree with
            the values recomputed from the stored items.
        """
        vals = list(self.values())
        assert self.n == len(vals)
        if vals:
            np.testing.assert_almost_equal(np.mean(vals), self.mean)
        else:
            assert np.isnan(self.mean)
        if len(vals) > 1:
            np.testing.assert_almost_equal(np.std(vals, ddof=1), self.std)
            np.testing.assert_almost_equal(self.standard_error,
                                           scipy.stats.sem(vals))
        else:
            assert np.isnan(self.std)
            assert np.isnan(self.standard_error)
0 commit comments