@@ -26,13 +26,14 @@ def data_sem(self):
26
26
def mean_values_per_point (self ):
27
27
return np .mean ([x .n for x in self ._data .values ()])
28
28
29
- def get_seed (self , point ):
29
+ def _next_seed (self , point ):
30
30
_data = self ._data .get (point , {})
31
31
pending_seeds = self .pending_points .get (point , set ())
32
32
seed = len (_data ) + len (pending_seeds )
33
33
if seed in _data or seed in pending_seeds :
34
- # means that the seed already exists, for example
34
+ # Means that the seed already exists, for example
35
35
# when '_data[point].keys() | pending_points[point] == {0, 2}'.
36
+ # Only happens when starting the learner after cancelling/loading.
36
37
return (set (range (seed )) - pending_seeds - _data .keys ()).pop ()
37
38
return seed
38
39
@@ -42,7 +43,7 @@ def loss_per_existing_point(self):
42
43
points = []
43
44
loss_improvements = []
44
45
for p , sem in self .data_sem .items ():
45
- points .append ((p , self .get_seed (p )))
46
+ points .append ((p , self ._next_seed (p )))
46
47
N = self .n_values (p )
47
48
sem_improvement = (1 - sqrt (N - 1 ) / sqrt (N )) * sem
48
49
loss_improvement = self .weight * sem_improvement / scale
@@ -136,7 +137,7 @@ def _get_data(self):
136
137
137
138
def add_average_mixin (cls ):
138
139
names = ('data' , 'data_sem' , 'mean_values_per_point' ,
139
- 'get_seed ' , 'loss_per_existing_point' , '_add_to_pending' ,
140
+ '_next_seed ' , 'loss_per_existing_point' , '_add_to_pending' ,
140
141
'_remove_from_to_pending' , '_add_to_data' , 'ask' , 'n_values' ,
141
142
'_normalize_new_points_loss_improvements' ,
142
143
'_normalize_existing_points_loss_improvements' ,
0 commit comments