Commit ae94904

revert the behaviour of previous commit, and scale point_loss with log(npoints)
1 parent 8b448a5 commit ae94904

3 files changed: +13 −45 lines changed

adaptive/learner/average1D.py (−11 lines)

@@ -37,17 +37,6 @@ def _get_neighbor_mapping_existing_points(self):
         return {k: [x for x in v if x is not None]
                 for k, v in self.neighbors.items()}

-    def loss_per_point(self):
-        loss_per_point = {}
-        for p in self.data.keys():
-            losses = []
-            for neighbor in self.neighbors[p]:
-                if neighbor is not None:
-                    ival = tuple(sorted((p, neighbor)))
-                    losses.append(self.losses[ival])
-            loss_per_point[p] = sum(losses) / len(losses)
-        return loss_per_point
-
     def unpack_point(self, x_seed):
         return x_seed
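Note: the removed 1D loss_per_point averaged, for each point, the losses of its adjacent intervals. A toy rendering of that pattern with made-up data (not the learner's actual API):

neighbors = {0.0: [None, 0.5], 0.5: [0.0, 1.0], 1.0: [0.5, None]}
interval_losses = {(0.0, 0.5): 0.2, (0.5, 1.0): 0.6}

loss_per_point = {}
for p, nbs in neighbors.items():
    # each interval is keyed by its sorted endpoints, as in the removed code
    losses = [interval_losses[tuple(sorted((p, n)))] for n in nbs if n is not None]
    loss_per_point[p] = sum(losses) / len(losses)
print(loss_per_point)  # {0.0: 0.2, 0.5: 0.4, 1.0: 0.6}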

adaptive/learner/average2D.py (−11 lines)

@@ -103,17 +103,6 @@ def inside_bounds(self, xy_seed):
         xy, seed = self.unpack_point(xy_seed)
         return super().inside_bounds(xy)

-    def loss_per_point(self):
-        ip = self.ip()
-        losses = self.loss_per_triangle(ip)
-        loss_per_point = defaultdict(list)
-        points = list(self.data.keys())
-        for simplex, loss in zip(ip.tri.vertices, losses):
-            for i in simplex:
-                loss_per_point[points[i]].append(loss)
-        loss_per_point = {p: sum(losses) / len(losses) for p, losses in loss_per_point.items()}
-        return loss_per_point
-
     def _ensure_point(self, point):
         """Adding a point with seed = 0.
         This used in '_fill_stack' in the Learner2D."""
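Note: the removed 2D loss_per_point spread each triangle's loss to its three vertices and averaged per vertex. A toy rendering with made-up points, triangles, and losses (no SciPy interpolator):

from collections import defaultdict

points = [(0, 0), (1, 0), (0, 1), (1, 1)]
triangles = [(0, 1, 2), (1, 2, 3)]  # vertex indices, like ip.tri.vertices
losses = [0.2, 0.6]                 # one loss per triangle

loss_per_point = defaultdict(list)
for simplex, loss in zip(triangles, losses):
    for i in simplex:
        loss_per_point[points[i]].append(loss)
loss_per_point = {p: sum(ls) / len(ls) for p, ls in loss_per_point.items()}
print(loss_per_point)  # {(0, 0): 0.2, (1, 0): 0.4, (0, 1): 0.4, (1, 1): 0.6}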

adaptive/learner/average_mixin.py (+13 −23 lines)

@@ -120,23 +120,8 @@ def _fill_seed_stack(self, till):
         if n <= 0:
             return

-        new_points, new_interval_losses, interval_neighbors = self._interval_losses(n)
-        existing_points, existing_points_sem_losses, point_neighbors = self._point_losses()
-        assert not interval_neighbors.keys() & point_neighbors.keys()
-
-        neighbors = {**interval_neighbors, **point_neighbors}
-
-        def normalize(points, losses, other_losses):
-            for i, ((point, _), loss_improvement) in enumerate(zip(points, losses)):
-                loss_other = sum(other_losses[p] for p in neighbors[point]) / len(neighbors[point])
-                normalized_loss = loss_improvement + sqrt(loss_improvement * loss_other)
-                losses[i] = min(normalized_loss, inf)
-
-        if neighbors:
-            sem_losses = self.data_sem
-            interval_losses = self.loss_per_point()
-            normalize(new_points, new_interval_losses, sem_losses)
-            normalize(existing_points, existing_points_sem_losses, interval_losses)
+        new_points, new_interval_losses = self._interval_losses(n)
+        existing_points, existing_points_sem_losses = self._point_losses()

         points = new_points + existing_points
         loss_improvements = new_interval_losses + existing_points_sem_losses
@@ -172,7 +157,7 @@ def _interval_losses(self, n):
         points, loss_improvements = self._ask_points_without_adding(n)
         if len(self._data) < 4:  # ANTON: fix (4) to bounds
             points = [(p, self.min_seeds_per_point) for p, s in points]
-            return points, loss_improvements, {}
+            return points, loss_improvements

         only_points = [p for p, s in points]  # points are [(x, seed), ...]
         neighbors = self._get_neighbor_mapping_new_points(only_points)
@@ -184,22 +169,25 @@ def _interval_losses(self, n):
             nseeds = max(n_neighbors, self.min_seeds_per_point)
             points.append((p, nseeds))

-        return points, loss_improvements, neighbors
+        return points, loss_improvements

     def _point_losses(self, fraction=1):
-        """Double the number of seeds."""
+        """Increase the number of seeds by 'fraction'."""
         if len(self.data) < 4:
-            return [], [], {}
+            return [], []
         scale = self.value_scale()
         points = []
         loss_improvements = []

         neighbors = self._get_neighbor_mapping_existing_points()
         mean_seeds_per_neighbor = self._mean_seeds_per_neighbor(neighbors)

+        npoints_factor = np.log2(self.npoints)
+
         for p, sem in self.data_sem.items():
             N = self.n_values(p)
-            n_more = self.n_values(p)  # double the amount of points
+            n_more = int(fraction * N)  # increase the amount of points by fraction
+            n_more = max(n_more, 1)  # at least 1 point
             points.append((p, n_more))
             needs_more_data = mean_seeds_per_neighbor[p] > 1.5 * N
             if needs_more_data:
@@ -211,8 +199,10 @@ def _point_losses(self, fraction=1):
             # We scale the values, sem(ys) / scale == sem(ys / scale).
             # and multiply them by a weight average_priority.
             loss_improvement = self.average_priority * sem_improvement / scale
+            if loss_improvement < inf:
+                loss_improvement *= npoints_factor
             loss_improvements.append(loss_improvement)
-        return points, loss_improvements, neighbors
+        return points, loss_improvements

     def _get_data(self):
         # change DataPoint -> dict for saving
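Note: the effect of the new log-scaling can be seen in a standalone sketch (made-up numbers; raw_loss stands in for average_priority * sem_improvement / scale, and n_more_seeds mirrors the new n_more rule):

import numpy as np

def n_more_seeds(N, fraction=1):
    # new rule: grow the seed count by `fraction`, but always by at least 1
    return max(int(fraction * N), 1)

raw_loss = 0.05
for npoints in (4, 16, 256, 4096):
    # the same raw point loss is weighted more heavily as the learner grows
    weighted = raw_loss * np.log2(npoints) if raw_loss < np.inf else raw_loss
    print(f"npoints={npoints:5d}  n_more(N=8)={n_more_seeds(8)}  "
          f"n_more(N=8, fraction=0.25)={n_more_seeds(8, 0.25)}  "
          f"weighted loss={weighted:.3f}")

With the default fraction=1 the seed count still doubles (n_more = N), so the pre-revert doubling behaviour is preserved, while the np.log2(self.npoints) weight presumably keeps resampling of existing points competitive with interval subdivision as more points are added.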
