@@ -26,31 +26,18 @@ def data_sem(self):
26
26
def mean_values_per_point (self ):
27
27
return np .mean ([x .n for x in self ._data .values ()])
28
28
29
- def _next_seed (self , point ):
29
+ def _next_seed (self , point , exclude = None ):
30
+ exclude = set (exclude ) if exclude is not None else set ()
30
31
_data = self ._data .get (point , {})
31
32
pending_seeds = self .pending_points .get (point , set ())
32
- seed = len (_data ) + len (pending_seeds )
33
- if seed in _data or seed in pending_seeds :
33
+ seed = len (_data ) + len (pending_seeds ) + len ( exclude )
34
+ if seed in _data or seed in pending_seeds or seed in exclude :
34
35
# Means that the seed already exists, for example
35
36
# when '_data[point].keys() | pending_points[point] == {0, 2}'.
36
37
# Only happens when starting the learner after cancelling/loading.
37
- return (set (range (seed )) - pending_seeds - _data .keys ()).pop ()
38
+ return (set (range (seed )) - pending_seeds - _data .keys () - exclude ).pop ()
38
39
return seed
39
40
40
- def loss_per_existing_point (self ):
41
- scale = self .value_scale ()
42
- points = []
43
- loss_improvements = []
44
- for p , sem in self .data_sem .items ():
45
- points .append ((p , self ._next_seed (p )))
46
- N = self .n_values (p )
47
- sem_improvement = (1 - sqrt (N - 1 ) / sqrt (N )) * sem
48
- loss_improvement = self .weight * sem_improvement / scale
49
- loss_improvements .append (loss_improvement )
50
- loss_improvements = self ._normalize_existing_points_loss_improvements (
51
- points , loss_improvements )
52
- return points , loss_improvements
53
-
54
41
def _add_to_pending (self , point ):
55
42
x , seed = self .unpack_point (point )
56
43
if x not in self .pending_points :
@@ -74,18 +61,35 @@ def _add_to_data(self, point, value):
74
61
def ask (self , n , tell_pending = True ):
75
62
"""Return n points that are expected to maximally reduce the loss."""
76
63
points , loss_improvements = [], []
77
- self ._fill_seed_stack (till = n )
78
64
79
65
# Take from the _seed_stack if there are any points.
66
+ self ._fill_seed_stack (till = n )
80
67
for i in range (n ):
81
- point , loss_improvement = self ._seed_stack [i ]
82
- points .append (point )
83
- loss_improvements .append (loss_improvement )
68
+ exclude_seeds = set ()
69
+ (point , nseeds ), loss_improvement = self ._seed_stack [i ]
70
+ for j in range (nseeds ):
71
+ seed = self ._next_seed (point , exclude_seeds )
72
+ exclude_seeds .add (seed )
73
+ points .append ((point , seed ))
74
+ loss_improvements .append (loss_improvement / nseeds )
75
+ if len (points ) >= n :
76
+ break
77
+ if len (points ) >= n :
78
+ break
84
79
85
80
if tell_pending :
86
81
for p in points :
87
82
self .tell_pending (p )
88
- self ._seed_stack = self ._seed_stack [n :]
83
+ nseeds_left = nseeds - j - 1 # of self._seed_stack[i]
84
+ if nseeds_left > 0 : # not all seeds have been asked
85
+ (point , nseeds ), loss_improvement = self ._seed_stack [i ]
86
+ self ._seed_stack [i ] = (
87
+ (point , nseeds_left ),
88
+ loss_improvement * nseeds_left / nseeds
89
+ )
90
+ self ._seed_stack = self ._seed_stack [i :]
91
+ else :
92
+ self ._seed_stack = self ._seed_stack [i + 1 :]
89
93
90
94
return points , loss_improvements
91
95
@@ -94,23 +98,29 @@ def _fill_seed_stack(self, till):
94
98
if n < 1 :
95
99
return
96
100
points , loss_improvements = [], []
97
- new_points , new_points_loss_improvements = (
98
- self ._ask_points_without_adding (n ))
99
- loss_improvements += self ._normalize_new_points_loss_improvements (
100
- new_points , new_points_loss_improvements )
101
+
102
+ new_points , new_points_loss_improvements = \
103
+ self .loss_per_new_point (n )
104
+
105
+ loss_improvements += new_points_loss_improvements
101
106
points += new_points
102
107
103
108
existing_points , existing_points_loss_improvements = \
104
109
self .loss_per_existing_point ()
110
+
105
111
points += existing_points
106
112
loss_improvements += existing_points_loss_improvements
107
113
108
114
loss_improvements , points = zip (* sorted (
109
115
zip (loss_improvements , points ), reverse = True ))
110
116
111
- points = list (points )[:n ]
112
- loss_improvements = list (loss_improvements )[:n ]
113
- self ._seed_stack += list (zip (points , loss_improvements ))
117
+ n_left = n
118
+ for loss_improvement , (point , nseeds ) in zip (
119
+ loss_improvements , points ):
120
+ self ._seed_stack .append (((point , nseeds ), loss_improvement ))
121
+ n_left -= nseeds
122
+ if n_left <= 0 :
123
+ break
114
124
115
125
def n_values (self , point ):
116
126
pending_points = self .pending_points .get (point , [])
@@ -121,40 +131,53 @@ def _mean_values_per_neighbor(self, neighbors):
121
131
return {p : sum (self .n_values (n ) for n in ns ) / len (ns )
122
132
for p , ns in neighbors .items ()}
123
133
124
- def _normalize_new_points_loss_improvements (self , points , loss_improvements ):
125
- """If we are suggesting a new (not yet suggested) point, then its
126
- 'loss_improvement' should be divided by the average number of values
127
- of its neigbors.
128
-
129
- This is because it will take a similar amount of points to reach
130
- that loss. """
134
+ def loss_per_new_point (self , n ):
135
+ """Add new points with at least self.min_values_per_point points
136
+ or with as many points as the neighbors have on average."""
137
+ points , loss_improvements = self ._ask_points_without_adding (n )
131
138
if len (self ._data ) < 4 :
132
- return loss_improvements
139
+ points = [(p , self .min_values_per_point ) for p , s in points ]
140
+ return points , loss_improvements
133
141
134
- only_points = [p for p , s in points ]
142
+ only_points = [p for p , s in points ] # points are [(x, seed), ...]
135
143
neighbors = self ._get_neighbor_mapping_new_points (only_points )
136
144
mean_values_per_neighbor = self ._mean_values_per_neighbor (neighbors )
137
145
138
- return [loss / mean_values_per_neighbor [p ]
139
- for (p , seed ), loss in zip (points , loss_improvements )]
146
+ points = []
147
+ for p in only_points :
148
+ n_neighbors = int (mean_values_per_neighbor [p ])
149
+ nseeds = max (n_neighbors , self .min_values_per_point )
150
+ points .append ((p , nseeds ))
140
151
141
- def _normalize_existing_points_loss_improvements (self , points , loss_improvements ):
142
- """If the neighbors of 'point' have twice as much values
143
- on average, then that 'point' should have an infinite loss.
152
+ return points , loss_improvements
144
153
145
- We do this because this point probably has a incorrect
146
- estimate of the sem."""
147
- if len (self ._data ) < 4 :
148
- return loss_improvements
154
+ def loss_per_existing_point (self ):
155
+ """Increase the number of seeds by 10%."""
156
+ if len (self .data ) < 4 :
157
+ return [], []
158
+ scale = self .value_scale ()
159
+ points = []
160
+ loss_improvements = []
149
161
150
162
neighbors = self ._get_neighbor_mapping_existing_points ()
151
163
mean_values_per_neighbor = self ._mean_values_per_neighbor (neighbors )
152
164
153
- def needs_more_data (p ):
154
- return mean_values_per_neighbor [p ] > 1.5 * self .n_values (p )
155
-
156
- return [inf if needs_more_data (p ) else loss
157
- for (p , seed ), loss in zip (points , loss_improvements )]
165
+ for p , sem in self .data_sem .items ():
166
+ n_neighbors = mean_values_per_neighbor [p ]
167
+ N = self .n_values (p )
168
+ n_more = int (0.1 * N ) # increase the amount of points by 10%
169
+ n_more = max (n_more , 1 ) # at least 1 point
170
+ points .append ((p , n_more ))
171
+ needs_more_data = n_neighbors > 1.5 * N
172
+ if needs_more_data :
173
+ loss_improvement = inf
174
+ else :
175
+ # This is the improvement considering we will add
176
+ # n_more seeds to the stack.
177
+ sem_improvement = (1 / sqrt (N ) - 1 / sqrt (N + n_more )) * sem
178
+ loss_improvement = self .weight * sem_improvement / scale # XXX: Do I need to divide by the scale?
179
+ loss_improvements .append (loss_improvement )
180
+ return points , loss_improvements
158
181
159
182
def _get_data (self ):
160
183
# change DataPoint -> dict for saving
@@ -165,9 +188,7 @@ def add_average_mixin(cls):
165
188
names = ('data' , 'data_sem' , 'mean_values_per_point' ,
166
189
'_next_seed' , 'loss_per_existing_point' , '_add_to_pending' ,
167
190
'_remove_from_to_pending' , '_add_to_data' , 'ask' , 'n_values' ,
168
- '_normalize_new_points_loss_improvements' ,
169
- '_normalize_existing_points_loss_improvements' ,
170
- '_mean_values_per_neighbor' ,
191
+ 'loss_per_new_point' , '_mean_values_per_neighbor' ,
171
192
'_get_data' , '_fill_seed_stack' )
172
193
173
194
for name in names :
0 commit comments