Commit 17c9b79

Fill the seed stack either by increasing the number of seeds of existing
points by 10% or by adding new points with as many seeds as their
neighbors have on average.
1 parent: 2681aeb

1 file changed, +77 −56 lines

adaptive/learner/average_mixin.py

@@ -26,31 +26,18 @@ def data_sem(self):
 def mean_values_per_point(self):
     return np.mean([x.n for x in self._data.values()])

-def _next_seed(self, point):
+def _next_seed(self, point, exclude=None):
+    exclude = set(exclude) if exclude is not None else set()
     _data = self._data.get(point, {})
     pending_seeds = self.pending_points.get(point, set())
-    seed = len(_data) + len(pending_seeds)
-    if seed in _data or seed in pending_seeds:
+    seed = len(_data) + len(pending_seeds) + len(exclude)
+    if seed in _data or seed in pending_seeds or seed in exclude:
         # Means that the seed already exists, for example
         # when '_data[point].keys() | pending_points[point] == {0, 2}'.
         # Only happens when starting the learner after cancelling/loading.
-        return (set(range(seed)) - pending_seeds - _data.keys()).pop()
+        return (set(range(seed)) - pending_seeds - _data.keys() - exclude).pop()
     return seed

-def loss_per_existing_point(self):
-    scale = self.value_scale()
-    points = []
-    loss_improvements = []
-    for p, sem in self.data_sem.items():
-        points.append((p, self._next_seed(p)))
-        N = self.n_values(p)
-        sem_improvement = (1 - sqrt(N - 1) / sqrt(N)) * sem
-        loss_improvement = self.weight * sem_improvement / scale
-        loss_improvements.append(loss_improvement)
-    loss_improvements = self._normalize_existing_points_loss_improvements(
-        points, loss_improvements)
-    return points, loss_improvements
-
 def _add_to_pending(self, point):
     x, seed = self.unpack_point(point)
     if x not in self.pending_points:
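
Note: below is a minimal standalone sketch (not part of the commit) of the
seed-selection rule in _next_seed; data_seeds and pending_seeds are
hypothetical stand-ins for self._data[point].keys() and
self.pending_points[point].

def next_seed(data_seeds, pending_seeds, exclude=None):
    exclude = set(exclude) if exclude is not None else set()
    # The candidate seed is the count of all seeds taken so far...
    seed = len(data_seeds) + len(pending_seeds) + len(exclude)
    if seed in data_seeds or seed in pending_seeds or seed in exclude:
        # ...unless the taken seeds have a hole (e.g. {0, 2}), in which
        # case the lowest free seed below the count is used instead.
        return (set(range(seed)) - set(pending_seeds)
                - set(data_seeds) - exclude).pop()
    return seed

assert next_seed({0, 2}, set()) == 1            # fills the hole
assert next_seed({0, 1}, {2}) == 3              # next free seed
assert next_seed({0}, set(), exclude={1}) == 2  # skips excluded seeds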
@@ -74,18 +61,35 @@ def _add_to_data(self, point, value):
 def ask(self, n, tell_pending=True):
     """Return n points that are expected to maximally reduce the loss."""
     points, loss_improvements = [], []
-    self._fill_seed_stack(till=n)

     # Take from the _seed_stack if there are any points.
+    self._fill_seed_stack(till=n)
     for i in range(n):
-        point, loss_improvement = self._seed_stack[i]
-        points.append(point)
-        loss_improvements.append(loss_improvement)
+        exclude_seeds = set()
+        (point, nseeds), loss_improvement = self._seed_stack[i]
+        for j in range(nseeds):
+            seed = self._next_seed(point, exclude_seeds)
+            exclude_seeds.add(seed)
+            points.append((point, seed))
+            loss_improvements.append(loss_improvement / nseeds)
+            if len(points) >= n:
+                break
+        if len(points) >= n:
+            break

     if tell_pending:
         for p in points:
             self.tell_pending(p)
-        self._seed_stack = self._seed_stack[n:]
+        nseeds_left = nseeds - j - 1  # of self._seed_stack[i]
+        if nseeds_left > 0:  # not all seeds have been asked
+            (point, nseeds), loss_improvement = self._seed_stack[i]
+            self._seed_stack[i] = (
+                (point, nseeds_left),
+                loss_improvement * nseeds_left / nseeds
+            )
+            self._seed_stack = self._seed_stack[i:]
+        else:
+            self._seed_stack = self._seed_stack[i+1:]

     return points, loss_improvements
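
Note: a hypothetical walk-through (all values invented) of the bookkeeping in
ask above. A stack entry ((point, nseeds), loss_improvement) expands into
individual (point, seed) requests, and a partially consumed entry stays on
the stack with its remaining seeds and a proportionally reduced loss
improvement.

# The stack holds two entries and we ask for n = 3 points.
seed_stack = [((0.5, 4), 1.0), ((0.2, 2), 0.5)]

# The first entry yields three requests, each credited 1.0 / 4:
#   points            -> [(0.5, 0), (0.5, 1), (0.5, 2)]
#   loss_improvements -> [0.25, 0.25, 0.25]
# One of its four seeds is unused (nseeds_left == 1), so after
# tell_pending the stack becomes:
#   [((0.5, 1), 1.0 * 1 / 4), ((0.2, 2), 0.5)]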

@@ -94,23 +98,29 @@ def _fill_seed_stack(self, till):
     if n < 1:
         return
     points, loss_improvements = [], []
-    new_points, new_points_loss_improvements = (
-        self._ask_points_without_adding(n))
-    loss_improvements += self._normalize_new_points_loss_improvements(
-        new_points, new_points_loss_improvements)
+
+    new_points, new_points_loss_improvements = \
+        self.loss_per_new_point(n)
+
+    loss_improvements += new_points_loss_improvements
     points += new_points

     existing_points, existing_points_loss_improvements = \
         self.loss_per_existing_point()
+
     points += existing_points
     loss_improvements += existing_points_loss_improvements

     loss_improvements, points = zip(*sorted(
         zip(loss_improvements, points), reverse=True))

-    points = list(points)[:n]
-    loss_improvements = list(loss_improvements)[:n]
-    self._seed_stack += list(zip(points, loss_improvements))
+    n_left = n
+    for loss_improvement, (point, nseeds) in zip(
+            loss_improvements, points):
+        self._seed_stack.append(((point, nseeds), loss_improvement))
+        n_left -= nseeds
+        if n_left <= 0:
+            break

 def n_values(self, point):
     pending_points = self.pending_points.get(point, [])
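
Note: the fill step can be paraphrased as the sketch below (candidate values
invented): merge the suggestions from loss_per_new_point and
loss_per_existing_point, sort them by loss improvement, and append entries
until at least 'till' seeds are covered.

# Each candidate is ((point, nseeds), loss_improvement).
candidates = [((0.1, 3), 0.8), ((0.7, 2), 1.2), ((0.4, 5), 0.3)]
candidates.sort(key=lambda entry: entry[1], reverse=True)

seed_stack, n_left = [], 4  # till = 4
for (point, nseeds), loss_improvement in candidates:
    seed_stack.append(((point, nseeds), loss_improvement))
    n_left -= nseeds
    if n_left <= 0:
        break

# seed_stack == [((0.7, 2), 1.2), ((0.1, 3), 0.8)] -> covers 5 >= 4 seeds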
@@ -121,40 +131,53 @@ def _mean_values_per_neighbor(self, neighbors):
     return {p: sum(self.n_values(n) for n in ns) / len(ns)
             for p, ns in neighbors.items()}

-def _normalize_new_points_loss_improvements(self, points, loss_improvements):
-    """If we are suggesting a new (not yet suggested) point, then its
-    'loss_improvement' should be divided by the average number of values
-    of its neigbors.
-
-    This is because it will take a similar amount of points to reach
-    that loss. """
+def loss_per_new_point(self, n):
+    """Add new points with at least self.min_values_per_point points
+    or with as many points as the neighbors have on average."""
+    points, loss_improvements = self._ask_points_without_adding(n)
     if len(self._data) < 4:
-        return loss_improvements
+        points = [(p, self.min_values_per_point) for p, s in points]
+        return points, loss_improvements

-    only_points = [p for p, s in points]
+    only_points = [p for p, s in points]  # points are [(x, seed), ...]
     neighbors = self._get_neighbor_mapping_new_points(only_points)
     mean_values_per_neighbor = self._mean_values_per_neighbor(neighbors)

-    return [loss / mean_values_per_neighbor[p]
-            for (p, seed), loss in zip(points, loss_improvements)]
+    points = []
+    for p in only_points:
+        n_neighbors = int(mean_values_per_neighbor[p])
+        nseeds = max(n_neighbors, self.min_values_per_point)
+        points.append((p, nseeds))

-def _normalize_existing_points_loss_improvements(self, points, loss_improvements):
-    """If the neighbors of 'point' have twice as much values
-    on average, then that 'point' should have an infinite loss.
+    return points, loss_improvements

-    We do this because this point probably has a incorrect
-    estimate of the sem."""
-    if len(self._data) < 4:
-        return loss_improvements
+def loss_per_existing_point(self):
+    """Increase the number of seeds by 10%."""
+    if len(self.data) < 4:
+        return [], []
+    scale = self.value_scale()
+    points = []
+    loss_improvements = []

     neighbors = self._get_neighbor_mapping_existing_points()
     mean_values_per_neighbor = self._mean_values_per_neighbor(neighbors)

-    def needs_more_data(p):
-        return mean_values_per_neighbor[p] > 1.5 * self.n_values(p)
-
-    return [inf if needs_more_data(p) else loss
-            for (p, seed), loss in zip(points, loss_improvements)]
+    for p, sem in self.data_sem.items():
+        n_neighbors = mean_values_per_neighbor[p]
+        N = self.n_values(p)
+        n_more = int(0.1 * N)  # increase the amount of points by 10%
+        n_more = max(n_more, 1)  # at least 1 point
+        points.append((p, n_more))
+        needs_more_data = n_neighbors > 1.5 * N
+        if needs_more_data:
+            loss_improvement = inf
+        else:
+            # This is the improvement considering we will add
+            # n_more seeds to the stack.
+            sem_improvement = (1 / sqrt(N) - 1 / sqrt(N + n_more)) * sem
+            loss_improvement = self.weight * sem_improvement / scale  # XXX: Do I need to divide by the scale?
+        loss_improvements.append(loss_improvement)
+    return points, loss_improvements

 def _get_data(self):
     # change DataPoint -> dict for saving
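
Note: a quick numeric check of the new priority formula for an existing
point, with weight and scale assumed to be 1 (both are learner-dependent):

from math import sqrt

N, sem, weight, scale = 100, 0.05, 1.0, 1.0  # invented values
n_more = max(int(0.1 * N), 1)                # 10% more seeds -> 10
sem_improvement = (1 / sqrt(N) - 1 / sqrt(N + n_more)) * sem
loss_improvement = weight * sem_improvement / scale
# sem_improvement ~= 2.3e-4: a small, finite priority, whereas a point
# whose neighbors hold more than 1.5x its values gets priority inf.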
@@ -165,9 +188,7 @@ def add_average_mixin(cls):
     names = ('data', 'data_sem', 'mean_values_per_point',
              '_next_seed', 'loss_per_existing_point', '_add_to_pending',
              '_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
-             '_normalize_new_points_loss_improvements',
-             '_normalize_existing_points_loss_improvements',
-             '_mean_values_per_neighbor',
+             'loss_per_new_point', '_mean_values_per_neighbor',
              '_get_data', '_fill_seed_stack')

     for name in names:
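
Note: the loop body falls outside this hunk. A common shape for this kind of
injection (an assumption, not shown in the diff) copies each module-level
function onto the class:

def add_average_mixin(cls):
    names = ('data', 'data_sem', 'ask')  # abbreviated; full tuple above
    for name in names:
        # Attach each module-level function to the class as a method.
        setattr(cls, name, globals()[name])
    return cls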
