Skip to content

Commit 51f4292

Browse files
committed
add comments and make methods private
1 parent d0dab50 commit 51f4292

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

adaptive/learner/average_mixin.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ def data_sem(self):
2626
def mean_values_per_point(self):
2727
return np.mean([x.n for x in self._data.values()])
2828

29-
def get_seed(self, point):
29+
def _next_seed(self, point):
3030
_data = self._data.get(point, {})
3131
pending_seeds = self.pending_points.get(point, set())
3232
seed = len(_data) + len(pending_seeds)
3333
if seed in _data or seed in pending_seeds:
34-
# means that the seed already exists, for example
34+
# Means that the seed already exists, for example
3535
# when '_data[point].keys() | pending_points[point] == {0, 2}'.
36+
# Only happens when starting the learner after cancelling/loading.
3637
return (set(range(seed)) - pending_seeds - _data.keys()).pop()
3738
return seed
3839

3940
def loss_per_existing_point(self):
4041
scale = self.value_scale()
41-
4242
points = []
4343
loss_improvements = []
4444
for p, sem in self.data_sem.items():
45-
points.append((p, self.get_seed(p)))
45+
points.append((p, self._next_seed(p)))
4646
N = self.n_values(p)
4747
sem_improvement = (1 - sqrt(N - 1) / sqrt(N)) * sem
4848
loss_improvement = self.weight * sem_improvement / scale
@@ -102,8 +102,12 @@ def _mean_values_per_neighbor(self, neighbors):
102102
for p, ns in neighbors.items()}
103103

104104
def _normalize_new_points_loss_improvements(self, points, loss_improvements):
105-
"""If we are suggesting a new point, then its 'loss_improvement' should
106-
be divided by the average number of values of its neigbors."""
105+
"""If we are suggesting a new (not yet suggested) point, then its
106+
'loss_improvement' should be divided by the average number of values
107+
of its neigbors.
108+
109+
This is because it will take a similar amount of points to reach
110+
that loss. """
107111
if len(self._data) < 4:
108112
return loss_improvements
109113

@@ -116,7 +120,10 @@ def _normalize_new_points_loss_improvements(self, points, loss_improvements):
116120

117121
def _normalize_existing_points_loss_improvements(self, points, loss_improvements):
118122
"""If the neighbors of 'point' have twice as much values
119-
on average, then that 'point' should have an infinite loss."""
123+
on average, then that 'point' should have an infinite loss.
124+
125+
We do this because this point probably has a incorrect
126+
estimate of the sem."""
120127
if len(self._data) < 4:
121128
return loss_improvements
122129

@@ -136,7 +143,7 @@ def _get_data(self):
136143

137144
def add_average_mixin(cls):
138145
names = ('data', 'data_sem', 'mean_values_per_point',
139-
'get_seed', 'loss_per_existing_point', '_add_to_pending',
146+
'_next_seed', 'loss_per_existing_point', '_add_to_pending',
140147
'_remove_from_to_pending', '_add_to_data', 'ask', 'n_values',
141148
'_normalize_new_points_loss_improvements',
142149
'_normalize_existing_points_loss_improvements',

0 commit comments

Comments
 (0)