@@ -27,7 +27,7 @@ def data(self):
 
     @property
     def data_sem(self):
-        return {k: v.standard_error if v.n >= self.min_values_per_point else inf
+        return {k: v.standard_error if v.n >= self.min_seeds_per_point else inf
                 for k, v in self._data.items()}
 
     def mean_seeds_per_point(self):
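The renamed attribute gates the per-point error estimate: until a point has at least min_seeds_per_point seeds, its standard error is reported as inf so the learner keeps sampling it. A minimal stand-alone sketch of that rule (standard_error and the sample values below are invented for illustration, not the learner's DataPoint objects):

from math import inf, sqrt
from statistics import stdev

def standard_error(samples, min_seeds_per_point=3):
    # SEM = sample standard deviation / sqrt(n); report inf until
    # enough seeds have been collected, mirroring data_sem above.
    n = len(samples)
    if n < min_seeds_per_point:
        return inf
    return stdev(samples) / sqrt(n)

print(standard_error([1.0, 1.2]))            # inf, too few seeds
print(standard_error([1.0, 1.2, 0.9, 1.1]))  # ~0.065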
@@ -93,6 +93,7 @@ def ask(self, n, tell_pending=True):
             if not remaining:
                 break
 
+        # change from dict to list
         points = [(point, seed) for point, seeds in points.items()
                   for seed in seeds]
         loss_improvements = [loss_improvements[point] for point in points]
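The new comment marks the spot where ask flattens the {point: seeds} mapping into the [(point, seed), ...] pairs that are actually handed out. With made-up values:

# Hypothetical request: x = 0.0 still needs seeds 0 and 1, x = 0.5 needs seed 0.
points = {0.0: {0, 1}, 0.5: {0}}

flat = [(point, seed) for point, seeds in points.items()
        for seed in seeds]
print(flat)  # [(0.0, 0), (0.0, 1), (0.5, 0)]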
@@ -119,14 +120,31 @@ def _fill_seed_stack(self, till):
         if n <= 0:
             return
 
-        new_points, new_points_losses = self._interval_losses(n)
-        existing_points, existing_points_losses = self._point_losses()
+        new_points, new_interval_losses, interval_neighbors = self._interval_losses(n)
+        existing_points, existing_points_sem_losses, point_neighbors = self._point_losses()
+        assert not interval_neighbors.keys() & point_neighbors.keys()
+
+        neighbors = {**interval_neighbors, **point_neighbors}
+
+        def normalize(points, losses, other_losses):
+            for i, ((point, _), loss_improvement) in enumerate(zip(points, losses)):
+                loss_other = sum(other_losses[p] for p in neighbors[point]) / len(neighbors[point])
+                normalized_loss = loss_improvement + sqrt(loss_improvement * loss_other)
+                losses[i] = min(normalized_loss, inf)
+
+        if neighbors:
+            sem_losses = self.data_sem
+            interval_losses = self.loss_per_point()
+            normalize(new_points, new_interval_losses, sem_losses)
+            normalize(existing_points, existing_points_sem_losses, interval_losses)
 
         points = new_points + existing_points
-        loss_improvements = new_points_losses + existing_points_losses
+        loss_improvements = new_interval_losses + existing_points_sem_losses
 
-        loss_improvements, points = zip(*sorted(
-            zip(loss_improvements, points), reverse=True))  # ANTON: sort by (loss_improvement / nseeds)
+        priority = [(-loss / nseeds, loss, (point, nseeds))
+                    for loss, (point, nseeds) in zip(loss_improvements, points)]
+
+        _, loss_improvements, points = zip(*sorted(priority))
 
         # Add points to the _seed_stack, it can happen that its
         # length exceeds the number of requested points.
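Two things change in _fill_seed_stack: each candidate's loss is blended with the complementary loss averaged over its neighbors (SEM loss for new intervals, interval loss for existing points) via loss + sqrt(loss * loss_other), and the old reverse=True sort is replaced by an explicit key so that candidates are ranked by loss improvement per requested seed. A minimal sketch of that ranking with made-up numbers:

# Hypothetical candidates: (x, nseeds) requests and their loss improvements.
points = [(0.25, 3), (0.75, 6), (0.5, 3)]
loss_improvements = [0.9, 1.2, 0.3]

# Sorting on -loss / nseeds puts the best improvement *per seed* first,
# so a cheap 3-seed request can outrank a larger 6-seed one.
priority = [(-loss / nseeds, loss, (point, nseeds))
            for loss, (point, nseeds) in zip(loss_improvements, points)]
_, loss_improvements, points = zip(*sorted(priority))
print(points)  # ((0.25, 3), (0.75, 6), (0.5, 3))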
@@ -149,12 +167,12 @@ def _mean_seeds_per_neighbor(self, neighbors):
                 for p, ns in neighbors.items()}
 
     def _interval_losses(self, n):
-        """Add new points with at least self.min_values_per_point points
+        """Add new points with at least self.min_seeds_per_point points
         or with as many points as the neighbors have on average."""
         points, loss_improvements = self._ask_points_without_adding(n)
         if len(self._data) < 4:  # ANTON: fix (4) to bounds
-            points = [(p, self.min_values_per_point) for p, s in points]
-            return points, loss_improvements
+            points = [(p, self.min_seeds_per_point) for p, s in points]
+            return points, loss_improvements, {}
 
         only_points = [p for p, s in points]  # points are [(x, seed), ...]
         neighbors = self._get_neighbor_mapping_new_points(only_points)
@@ -163,15 +181,15 @@ def _interval_losses(self, n):
         points = []
         for p in only_points:
             n_neighbors = int(mean_seeds_per_neighbor[p])
-            nseeds = max(n_neighbors, self.min_values_per_point)
+            nseeds = max(n_neighbors, self.min_seeds_per_point)
             points.append((p, nseeds))
 
-        return points, loss_improvements
+        return points, loss_improvements, neighbors
 
     def _point_losses(self, fraction=1):
         """Double the number of seeds."""
         if len(self.data) < 4:
-            return [], []
+            return [], [], {}
         scale = self.value_scale()
         points = []
         loss_improvements = []
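In _interval_losses above, the seed count requested for a new point is capped from below by min_seeds_per_point but otherwise copies the average seed count of its existing neighbors; roughly (with invented numbers):

# Invented example: the new x = 0.25 has neighbors holding 4 and 7 seeds.
min_seeds_per_point = 3
mean_seeds_per_neighbor = {0.25: (4 + 7) / 2}

nseeds = max(int(mean_seeds_per_neighbor[0.25]), min_seeds_per_point)
print(nseeds)  # 5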
@@ -181,8 +199,7 @@ def _point_losses(self, fraction=1):
 
         for p, sem in self.data_sem.items():
             N = self.n_values(p)
-            n_more = int(fraction * N)  # double the amount of points
-            n_more = max(n_more, 1)  # at least 1 point
+            n_more = self.n_values(p)  # double the amount of points
             points.append((p, n_more))
             needs_more_data = mean_seeds_per_neighbor[p] > 1.5 * N
             if needs_more_data:
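The fraction argument is effectively dropped here: n_more now equals the current number of values, so every existing point asks to double its seeds. Doubling is a natural step because the standard error of a mean shrinks like 1/sqrt(N), so each doubling buys a fixed ~29% reduction; an illustration with assumed numbers:

from math import sqrt

# Assumed sample standard deviation s and current seed count N.
s, N = 1.0, 8
sem_now = s / sqrt(N)
sem_after_doubling = s / sqrt(2 * N)
print(sem_after_doubling / sem_now)  # 0.707..., i.e. ~29% smaller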
@@ -195,7 +212,7 @@ def _point_losses(self, fraction=1):
             # and multiply them by a weight average_priority.
             loss_improvement = self.average_priority * sem_improvement / scale
             loss_improvements.append(loss_improvement)
-        return points, loss_improvements
+        return points, loss_improvements, neighbors
 
     def _get_data(self):
         # change DataPoint -> dict for saving