@@ -120,23 +120,8 @@ def _fill_seed_stack(self, till):
120
120
if n <= 0 :
121
121
return
122
122
123
- new_points , new_interval_losses , interval_neighbors = self ._interval_losses (n )
124
- existing_points , existing_points_sem_losses , point_neighbors = self ._point_losses ()
125
- assert not interval_neighbors .keys () & point_neighbors .keys ()
126
-
127
- neighbors = {** interval_neighbors , ** point_neighbors }
128
-
129
- def normalize (points , losses , other_losses ):
130
- for i , ((point , _ ), loss_improvement ) in enumerate (zip (points , losses )):
131
- loss_other = sum (other_losses [p ] for p in neighbors [point ]) / len (neighbors [point ])
132
- normalized_loss = loss_improvement + sqrt (loss_improvement * loss_other )
133
- losses [i ] = min (normalized_loss , inf )
134
-
135
- if neighbors :
136
- sem_losses = self .data_sem
137
- interval_losses = self .loss_per_point ()
138
- normalize (new_points , new_interval_losses , sem_losses )
139
- normalize (existing_points , existing_points_sem_losses , interval_losses )
123
+ new_points , new_interval_losses = self ._interval_losses (n )
124
+ existing_points , existing_points_sem_losses = self ._point_losses ()
140
125
141
126
points = new_points + existing_points
142
127
loss_improvements = new_interval_losses + existing_points_sem_losses
@@ -172,7 +157,7 @@ def _interval_losses(self, n):
172
157
points , loss_improvements = self ._ask_points_without_adding (n )
173
158
if len (self ._data ) < 4 : # ANTON: fix (4) to bounds
174
159
points = [(p , self .min_seeds_per_point ) for p , s in points ]
175
- return points , loss_improvements , {}
160
+ return points , loss_improvements
176
161
177
162
only_points = [p for p , s in points ] # points are [(x, seed), ...]
178
163
neighbors = self ._get_neighbor_mapping_new_points (only_points )
@@ -184,22 +169,25 @@ def _interval_losses(self, n):
184
169
nseeds = max (n_neighbors , self .min_seeds_per_point )
185
170
points .append ((p , nseeds ))
186
171
187
- return points , loss_improvements , neighbors
172
+ return points , loss_improvements
188
173
189
174
def _point_losses (self , fraction = 1 ):
190
- """Double the number of seeds."""
175
+ """Increase the number of seeds by 'fraction' ."""
191
176
if len (self .data ) < 4 :
192
- return [], [], {}
177
+ return [], []
193
178
scale = self .value_scale ()
194
179
points = []
195
180
loss_improvements = []
196
181
197
182
neighbors = self ._get_neighbor_mapping_existing_points ()
198
183
mean_seeds_per_neighbor = self ._mean_seeds_per_neighbor (neighbors )
199
184
185
+ npoints_factor = np .log2 (self .npoints )
186
+
200
187
for p , sem in self .data_sem .items ():
201
188
N = self .n_values (p )
202
- n_more = self .n_values (p ) # double the amount of points
189
+ n_more = int (fraction * N ) # increase the amount of points by fraction
190
+ n_more = max (n_more , 1 ) # at least 1 point
203
191
points .append ((p , n_more ))
204
192
needs_more_data = mean_seeds_per_neighbor [p ] > 1.5 * N
205
193
if needs_more_data :
@@ -211,8 +199,10 @@ def _point_losses(self, fraction=1):
211
199
# We scale the values; note that sem(ys) / scale == sem(ys / scale).
212
200
# The scaled improvement is then multiplied by the weight average_priority.
213
201
loss_improvement = self .average_priority * sem_improvement / scale
202
+ if loss_improvement < inf :
203
+ loss_improvement *= npoints_factor
214
204
loss_improvements .append (loss_improvement )
215
- return points , loss_improvements , neighbors
205
+ return points , loss_improvements
216
206
217
207
def _get_data (self ):
218
208
# change DataPoint -> dict for saving
0 commit comments