
Commit e824326

normalize the loss improvements by the std or interval_loss of the neighbors
1 parent 3928eeb commit e824326

File tree

3 files changed: +54 −15 lines changed

adaptive/learner/average1D.py
adaptive/learner/average2D.py
adaptive/learner/average_mixin.py

adaptive/learner/average1D.py (+11)

@@ -37,6 +37,17 @@ def _get_neighbor_mapping_existing_points(self):
         return {k: [x for x in v if x is not None]
                 for k, v in self.neighbors.items()}
 
+    def loss_per_point(self):
+        loss_per_point = {}
+        for p in self.data.keys():
+            losses = []
+            for neighbor in self.neighbors[p]:
+                if neighbor is not None:
+                    ival = tuple(sorted((p, neighbor)))
+                    losses.append(self.losses[ival])
+            loss_per_point[p] = sum(losses) / len(losses)
+        return loss_per_point
+
     def unpack_point(self, x_seed):
         return x_seed
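In 1D, each point's loss is the average loss of its adjacent intervals. A minimal standalone sketch of the same computation on toy data (the `neighbors` and `losses` mappings here are hypothetical stand-ins for the learner's internal state):

    # Hypothetical stand-ins for the learner's internals: `neighbors` maps
    # x -> [left, right] (None at a boundary), `losses` maps a sorted
    # interval tuple -> its loss.
    neighbors = {0.0: [None, 0.5], 0.5: [0.0, 1.0], 1.0: [0.5, None]}
    losses = {(0.0, 0.5): 0.2, (0.5, 1.0): 0.6}

    def loss_per_point(neighbors, losses):
        result = {}
        for p, nbs in neighbors.items():
            ivals = [tuple(sorted((p, n))) for n in nbs if n is not None]
            result[p] = sum(losses[ival] for ival in ivals) / len(ivals)
        return result

    print(loss_per_point(neighbors, losses))
    # {0.0: 0.2, 0.5: 0.4, 1.0: 0.6} -- interior points average both sides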

adaptive/learner/average2D.py (+11)

@@ -103,6 +103,17 @@ def inside_bounds(self, xy_seed):
         xy, seed = self.unpack_point(xy_seed)
         return super().inside_bounds(xy)
 
+    def loss_per_point(self):
+        ip = self.ip()
+        losses = self.loss_per_triangle(ip)
+        loss_per_point = defaultdict(list)
+        points = list(self.data.keys())
+        for simplex, loss in zip(ip.tri.vertices, losses):
+            for i in simplex:
+                loss_per_point[points[i]].append(loss)
+        loss_per_point = {p: sum(losses) / len(losses) for p, losses in loss_per_point.items()}
+        return loss_per_point
+
     def _ensure_point(self, point):
         """Adding a point with seed = 0.
         This used in '_fill_stack' in the Learner2D."""
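In 2D the same idea runs over the triangulation: every triangle's loss is credited to each of its three vertices, and a point's loss is the mean of what it collected. A runnable sketch using `scipy.spatial.Delaunay` as a stand-in for the learner's interpolator (`tri.simplices` below plays the role of `ip.tri.vertices` above; the losses are hypothetical):

    from collections import defaultdict

    import numpy as np
    from scipy.spatial import Delaunay

    points = np.array([(0, 0), (1, 0), (0, 1), (1, 1)])
    tri = Delaunay(points)  # two triangles covering the unit square
    triangle_losses = [0.3, 0.9]  # stand-in for loss_per_triangle

    collected = defaultdict(list)
    for simplex, loss in zip(tri.simplices, triangle_losses):
        for i in simplex:
            collected[tuple(points[i])].append(loss)

    loss_per_point = {p: sum(ls) / len(ls) for p, ls in collected.items()}
    # Points on the shared diagonal sit in both triangles and average
    # the two losses; the other corners keep their single triangle's loss.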

adaptive/learner/average_mixin.py (+32 −15)

@@ -27,7 +27,7 @@ def data(self):
 
     @property
     def data_sem(self):
-        return {k: v.standard_error if v.n >= self.min_values_per_point else inf
+        return {k: v.standard_error if v.n >= self.min_seeds_per_point else inf
                 for k, v in self._data.items()}
 
     def mean_seeds_per_point(self):
@@ -93,6 +93,7 @@ def ask(self, n, tell_pending=True):
             if not remaining:
                 break
 
+        # change from dict to list
         points = [(point, seed) for point, seeds in points.items()
                   for seed in seeds]
         loss_improvements = [loss_improvements[point] for point in points]
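The new comment marks where `ask` flattens the `point -> seeds` dict into a flat list of `(point, seed)` pairs. For illustration (hypothetical values):

    points = {0.5: [0, 1], 1.0: [0]}  # point -> list of seeds
    flat = [(point, seed) for point, seeds in points.items() for seed in seeds]
    # [(0.5, 0), (0.5, 1), (1.0, 0)]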
@@ -119,14 +120,31 @@ def _fill_seed_stack(self, till):
         if n <= 0:
             return
 
-        new_points, new_points_losses = self._interval_losses(n)
-        existing_points, existing_points_losses = self._point_losses()
+        new_points, new_interval_losses, interval_neighbors = self._interval_losses(n)
+        existing_points, existing_points_sem_losses, point_neighbors = self._point_losses()
+        assert not interval_neighbors.keys() & point_neighbors.keys()
+
+        neighbors = {**interval_neighbors, **point_neighbors}
+
+        def normalize(points, losses, other_losses):
+            for i, ((point, _), loss_improvement) in enumerate(zip(points, losses)):
+                loss_other = sum(other_losses[p] for p in neighbors[point]) / len(neighbors[point])
+                normalized_loss = loss_improvement + sqrt(loss_improvement * loss_other)
+                losses[i] = min(normalized_loss, inf)
+
+        if neighbors:
+            sem_losses = self.data_sem
+            interval_losses = self.loss_per_point()
+            normalize(new_points, new_interval_losses, sem_losses)
+            normalize(existing_points, existing_points_sem_losses, interval_losses)
 
         points = new_points + existing_points
-        loss_improvements = new_points_losses + existing_points_losses
+        loss_improvements = new_interval_losses + existing_points_sem_losses
 
-        loss_improvements, points = zip(*sorted(
-            zip(loss_improvements, points), reverse=True))  # ANTON: sort by (loss_improvement / nseeds)
+        priority = [(-loss / nseeds, loss, (point, nseeds))
+                    for loss, (point, nseeds) in zip(loss_improvements, points)]
+
+        _, loss_improvements, points = zip(*sorted(priority))
 
         # Add points to the _seed_stack, it can happen that its
         # length exceeds the number of requested points.
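This hunk is the heart of the commit: a candidate's loss improvement is boosted by the geometric mean of itself and its neighbors' average loss of the other kind (new interval points are normalized by the neighbors' standard errors, existing points by the neighbors' interval losses), and candidates are then ordered by loss improvement per requested seed. A self-contained sketch with hypothetical numbers:

    from math import sqrt, inf

    # Hypothetical candidates: (point, nseeds) plus raw loss improvements.
    points = [(0.25, 3), (0.75, 1)]
    losses = [0.4, 0.1]
    neighbors = {0.25: [0.0, 0.5], 0.75: [0.5, 1.0]}
    other_losses = {0.0: 0.2, 0.5: 0.8, 1.0: 0.1}  # e.g. the neighbors' SEMs

    def normalize(points, losses, other_losses):
        # Boost each loss improvement by the geometric mean of itself and
        # its neighbors' average loss of the *other* kind.
        for i, ((point, _), loss_improvement) in enumerate(zip(points, losses)):
            nbs = neighbors[point]
            loss_other = sum(other_losses[p] for p in nbs) / len(nbs)
            losses[i] = min(loss_improvement + sqrt(loss_improvement * loss_other), inf)

    normalize(points, losses, other_losses)  # losses -> [~0.847, ~0.312]

    # Highest loss improvement *per seed* pops first: negating the key makes
    # the ascending sort behave like a descending sort on loss / nseeds.
    priority = [(-loss / nseeds, loss, (point, nseeds))
                for loss, (point, nseeds) in zip(losses, points)]
    _, loss_improvements, ordered = zip(*sorted(priority))
    # ordered == ((0.75, 1), (0.25, 3)): 0.312/1 beats 0.847/3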
@@ -149,12 +167,12 @@ def _mean_seeds_per_neighbor(self, neighbors):
                 for p, ns in neighbors.items()}
 
     def _interval_losses(self, n):
-        """Add new points with at least self.min_values_per_point points
+        """Add new points with at least self.min_seeds_per_point points
         or with as many points as the neighbors have on average."""
         points, loss_improvements = self._ask_points_without_adding(n)
         if len(self._data) < 4:  # ANTON: fix (4) to bounds
-            points = [(p, self.min_values_per_point) for p, s in points]
-            return points, loss_improvements
+            points = [(p, self.min_seeds_per_point) for p, s in points]
+            return points, loss_improvements, {}
 
         only_points = [p for p, s in points]  # points are [(x, seed), ...]
         neighbors = self._get_neighbor_mapping_new_points(only_points)
@@ -163,15 +181,15 @@ def _interval_losses(self, n):
         points = []
         for p in only_points:
             n_neighbors = int(mean_seeds_per_neighbor[p])
-            nseeds = max(n_neighbors, self.min_values_per_point)
+            nseeds = max(n_neighbors, self.min_seeds_per_point)
             points.append((p, nseeds))
 
-        return points, loss_improvements
+        return points, loss_improvements, neighbors
 
     def _point_losses(self, fraction=1):
         """Double the number of seeds."""
         if len(self.data) < 4:
-            return [], []
+            return [], [], {}
         scale = self.value_scale()
         points = []
         loss_improvements = []
@@ -181,8 +199,7 @@ def _point_losses(self, fraction=1):
 
         for p, sem in self.data_sem.items():
             N = self.n_values(p)
-            n_more = int(fraction * N)  # double the amount of points
-            n_more = max(n_more, 1)  # at least 1 point
+            n_more = self.n_values(p)  # double the amount of points
             points.append((p, n_more))
             needs_more_data = mean_seeds_per_neighbor[p] > 1.5 * N
             if needs_more_data:
@@ -195,7 +212,7 @@ def _point_losses(self, fraction=1):
             # and multiply them by a weight average_priority.
             loss_improvement = self.average_priority * sem_improvement / scale
             loss_improvements.append(loss_improvement)
-        return points, loss_improvements
+        return points, loss_improvements, neighbors
 
     def _get_data(self):
         # change DataPoint -> dict for saving
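The diff does not show how `sem_improvement` is computed, but since the standard error of the mean scales as 1/sqrt(n), doubling a point's seeds is expected to shrink it by a factor of 1/sqrt(2). A sketch of that estimate (an assumption about the surrounding code, not part of this diff):

    from math import sqrt

    def expected_sem_improvement(sem, n_current, n_more):
        # SEM ~ sigma / sqrt(n): going from n_current to n_current + n_more
        # seeds is expected to reduce the SEM by this amount.
        return sem * (1 - sqrt(n_current / (n_current + n_more)))

    # Doubling the seeds of a point with sem = 0.1 and 8 seeds:
    expected_sem_improvement(0.1, 8, 8)  # ~0.029, i.e. sem -> sem / sqrt(2)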
