diff --git a/adaptive/learner/sequence_learner.py b/adaptive/learner/sequence_learner.py index c7398dfa4..99be75e79 100644 --- a/adaptive/learner/sequence_learner.py +++ b/adaptive/learner/sequence_learner.py @@ -5,28 +5,15 @@ from adaptive.learner.base_learner import BaseLearner -class _IgnoreFirstArgument: - """Remove the first argument from the call signature. +class _CallFromSequence: + """Call function with index of sequence.""" - The SequenceLearner's function receives a tuple ``(index, point)`` - but the original function only takes ``point``. - - This is the same as `lambda x: function(x[1])`, however, that is not - pickable. - """ - - def __init__(self, function): + def __init__(self, function, sequence): self.function = function + self.sequence = sequence - def __call__(self, index_point, *args, **kwargs): - index, point = index_point - return self.function(point, *args, **kwargs) - - def __getstate__(self): - return self.function - - def __setstate__(self, function): - self.__init__(function) + def __call__(self, index, *args, **kwargs): + return self.function(self.sequence[index], *args, **kwargs) class SequenceLearner(BaseLearner): @@ -40,7 +27,7 @@ class SequenceLearner(BaseLearner): Parameters ---------- function : callable - The function to learn. Must take a single element `sequence`. + The function to learn. Must take a single element of `sequence`. sequence : sequence The sequence to learn. @@ -58,7 +45,7 @@ class SequenceLearner(BaseLearner): def __init__(self, function, sequence): self._original_function = function - self.function = _IgnoreFirstArgument(function) + self.function = _CallFromSequence(function, sequence) self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)}) self._ntotal = len(sequence) self.sequence = copy(sequence) @@ -67,21 +54,18 @@ def __init__(self, function, sequence): def ask(self, n, tell_pending=True): indices = [] - points = [] loss_improvements = [] for index in self._to_do_indices: - if len(points) >= n: + if len(indices) >= n: break - point = self.sequence[index] indices.append(index) - points.append((index, point)) loss_improvements.append(1 / self._ntotal) if tell_pending: - for i, p in zip(indices, points): - self.tell_pending((i, p)) + for index in indices: + self.tell_pending(index) - return points, loss_improvements + return indices, loss_improvements def _get_data(self): return self.data @@ -89,9 +73,7 @@ def _get_data(self): def _set_data(self, data): if data: indices, values = zip(*data.items()) - # the points aren't used by tell, so we can safely pass None - points = [(i, None) for i in indices] - self.tell_many(points, values) + self.tell_many(indices, values) def loss(self, real=True): if not (self._to_do_indices or self.pending_points): @@ -105,14 +87,12 @@ def remove_unfinished(self): self._to_do_indices.add(i) self.pending_points = set() - def tell(self, point, value): - index, point = point + def tell(self, index, value): self.data[index] = value self.pending_points.discard(index) self._to_do_indices.discard(index) - def tell_pending(self, point): - index, point = point + def tell_pending(self, index): self.pending_points.add(index) self._to_do_indices.discard(index) diff --git a/adaptive/tests/test_learners.py b/adaptive/tests/test_learners.py index 6edee2b30..dc7f11eb2 100644 --- a/adaptive/tests/test_learners.py +++ b/adaptive/tests/test_learners.py @@ -281,19 +281,9 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs): M = random.randint(10, 30) pls = zip(*learner.ask(M)) cpls = zip(*control.ask(M)) - if learner_type is SequenceLearner: - # The SequenceLearner's points might not be hasable - points, values = zip(*pls) - indices, points = zip(*points) - cpoints, cvalues = zip(*cpls) - cindices, cpoints = zip(*cpoints) - assert (np.array(points) == np.array(cpoints)).all() - assert values == cvalues - assert indices == cindices - else: - # Point ordering is not defined, so compare as sets - assert set(pls) == set(cpls) + # Point ordering is not defined, so compare as sets + assert set(pls) == set(cpls) # XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55) @@ -324,20 +314,9 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs): pls = zip(*learner.ask(M)) cpls = zip(*control.ask(M)) - if learner_type is SequenceLearner: - # The SequenceLearner's points might not be hasable - points, values = zip(*pls) - indices, points = zip(*points) - - cpoints, cvalues = zip(*cpls) - cindices, cpoints = zip(*cpoints) - assert (np.array(points) == np.array(cpoints)).all() - assert values == cvalues - assert indices == cindices - else: - # Point ordering within a single call to 'ask' - # is not guaranteed to be the same by the API. - assert set(pls) == set(cpls) + # Point ordering within a single call to 'ask' + # is not guaranteed to be the same by the API. + assert set(pls) == set(cpls) @run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner) diff --git a/adaptive/tests/test_sequence_learner.py b/adaptive/tests/test_sequence_learner.py new file mode 100644 index 000000000..68ca956ca --- /dev/null +++ b/adaptive/tests/test_sequence_learner.py @@ -0,0 +1,26 @@ +import asyncio + +from adaptive import Runner, SequenceLearner +from adaptive.runner import SequentialExecutor + + +class FailOnce: + def __init__(self): + self.failed = False + + def __call__(self, value): + if self.failed: + return value + self.failed = True + raise RuntimeError + + +def test_fail_with_sequence_of_unhashable(): + # https://github.com/python-adaptive/adaptive/issues/265 + seq = [{1: 1}] # unhashable + learner = SequenceLearner(FailOnce(), sequence=seq) + runner = Runner( + learner, goal=SequenceLearner.done, retries=1, executor=SequentialExecutor() + ) + asyncio.get_event_loop().run_until_complete(runner.task) + assert runner.status() == "finished"