Skip to content

Commit 05be948

Browse files
committed
Use loss_goal and npoints_goal
1 parent 7900b5a commit 05be948

26 files changed

+153
-119
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def peak(x, a=0.01):
7575

7676

7777
learner = Learner1D(peak, bounds=(-1, 1))
78-
runner = Runner(learner, goal=0.01)
78+
runner = Runner(learner, loss_goal=0.01)
7979
runner.live_info()
8080
runner.live_plot()
8181
```

adaptive/runner.py

+58-17
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@
1414
import warnings
1515
from contextlib import suppress
1616
from datetime import datetime, timedelta
17-
from typing import Any, Callable
17+
from typing import Any, Callable, Union
1818

1919
import loky
2020

2121
from adaptive import BalancingLearner, BaseLearner, IntegratorLearner, SequenceLearner
2222
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
2323

24+
try:
25+
from typing import TypeAlias
26+
except ModuleNotFoundError:
27+
# Python <3.10
28+
from typing_extensions import TypeAlias
29+
2430
try:
2531
import ipyparallel
2632

@@ -60,6 +66,10 @@
6066
# and https://github.com/python-adaptive/adaptive/issues/301
6167
_default_executor = loky.get_reusable_executor
6268

69+
GoalTypes: TypeAlias = Union[
70+
Callable[[BaseLearner], bool], int, float, datetime, timedelta, None
71+
]
72+
6373

6474
class BaseRunner(metaclass=abc.ABCMeta):
6575
r"""Base class for runners that use `concurrent.futures.Executors`.
@@ -120,8 +130,10 @@ class BaseRunner(metaclass=abc.ABCMeta):
120130
def __init__(
121131
self,
122132
learner,
123-
goal,
124133
*,
134+
goal: GoalTypes = None,
135+
loss_goal: float | None = None,
136+
npoints_goal: int | None = None,
125137
executor=None,
126138
ntasks=None,
127139
log=False,
@@ -132,7 +144,7 @@ def __init__(
132144
):
133145

134146
self.executor = _ensure_executor(executor)
135-
self.goal = auto_goal(goal, learner, allow_running_forever)
147+
self.goal = _goal(learner, goal, loss_goal, npoints_goal, allow_running_forever)
136148

137149
self._max_tasks = ntasks
138150

@@ -376,8 +388,10 @@ class BlockingRunner(BaseRunner):
376388
def __init__(
377389
self,
378390
learner,
379-
goal,
380391
*,
392+
goal: GoalTypes = None,
393+
loss_goal: float | None = None,
394+
npoints_goal: int | None = None,
381395
executor=None,
382396
ntasks=None,
383397
log=False,
@@ -389,7 +403,9 @@ def __init__(
389403
raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.")
390404
super().__init__(
391405
learner,
392-
goal,
406+
goal=goal,
407+
loss_goal=loss_goal,
408+
npoints_goal=npoints_goal,
393409
executor=executor,
394410
ntasks=ntasks,
395411
log=log,
@@ -508,8 +524,10 @@ class AsyncRunner(BaseRunner):
508524
def __init__(
509525
self,
510526
learner,
511-
goal=None,
512527
*,
528+
goal: GoalTypes = None,
529+
loss_goal: float | None = None,
530+
npoints_goal: int | None = None,
513531
executor=None,
514532
ntasks=None,
515533
log=False,
@@ -537,7 +555,9 @@ def __init__(
537555

538556
super().__init__(
539557
learner,
540-
goal,
558+
goal=goal,
559+
loss_goal=loss_goal,
560+
npoints_goal=npoints_goal,
541561
executor=executor,
542562
ntasks=ntasks,
543563
log=log,
@@ -717,7 +737,13 @@ async def _saver():
717737
Runner = AsyncRunner
718738

719739

720-
def simple(learner, goal):
740+
def simple(
741+
learner,
742+
*,
743+
goal: GoalTypes = None,
744+
loss_goal: float | None = None,
745+
npoints_goal: int | None = None,
746+
):
721747
"""Run the learner until the goal is reached.
722748
723749
Requests a single point from the learner, evaluates
@@ -736,7 +762,7 @@ def simple(learner, goal):
736762
The end condition for the calculation. This function must take the
737763
learner as its sole argument, and return True if we should stop.
738764
"""
739-
goal = auto_goal(goal, learner)
765+
goal = _goal(learner, goal, loss_goal, npoints_goal, allow_running_forever=False)
740766
while not goal(learner):
741767
xs, _ = learner.ask(1)
742768
for x in xs:
@@ -871,14 +897,13 @@ def __call__(self, _):
871897
if self.start_time is None:
872898
self.start_time = datetime.now()
873899
return datetime.now() - self.start_time > self.dt
874-
elif isinstance(self.dt, datetime):
900+
if isinstance(self.dt, datetime):
875901
return datetime.now() > self.dt
876-
else:
877-
raise TypeError(f"`dt={self.dt}` is not a datetime or timedelta.")
902+
raise TypeError(f"`dt={self.dt}` is not a datetime or timedelta.")
878903

879904

880905
def auto_goal(
881-
goal: Callable[[BaseLearner], bool] | int | float | datetime | timedelta | None,
906+
goal: GoalTypes,
882907
learner: BaseLearner,
883908
allow_running_forever: bool = True,
884909
):
@@ -935,12 +960,28 @@ def auto_goal(
935960
return SequenceLearner.done
936961
if isinstance(learner, IntegratorLearner):
937962
return IntegratorLearner.done
938-
warnings.warn("Goal is None which means the learners continue forever!")
939-
if allow_running_forever:
940-
return lambda _: False
941-
else:
963+
if not allow_running_forever:
942964
raise ValueError(
943965
"Goal is None which means the learners"
944966
" continue forever and this is not allowed."
945967
)
968+
warnings.warn("Goal is None which means the learners continue forever!")
969+
return lambda _: False
946970
raise ValueError("Cannot determine goal from {goal}.")
971+
972+
973+
def _goal(
974+
learner: BaseLearner,
975+
goal: GoalTypes,
976+
loss_goal: float | None,
977+
npoints_goal: int | None,
978+
allow_running_forever: bool,
979+
):
980+
# goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
981+
if goal is not None and (loss_goal is not None or npoints_goal is not None):
982+
raise ValueError("Either goal, loss_goal, or npoints_goal can be specified.")
983+
if loss_goal is not None:
984+
goal = float(loss_goal)
985+
if npoints_goal is not None:
986+
goal = int(npoints_goal)
987+
return auto_goal(goal, learner, allow_running_forever)

adaptive/tests/test_average_learner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def constant_function(seed):
6161
learner = AverageLearner(
6262
constant_function, atol=0.01, rtol=0.01, min_npoints=min_npoints
6363
)
64-
simple(learner, 1.0)
64+
simple(learner, loss_goal=1.0)
6565
assert learner.npoints >= max(2, min_npoints)
6666

6767

adaptive/tests/test_learner1d.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_equal(l1, l2):
298298
for function in [f, f_vec]:
299299
learner = Learner1D(function, bounds=(-1, 1))
300300
learner2 = Learner1D(function, bounds=(-1, 1))
301-
simple(learner, goal=200)
301+
simple(learner, npoints_goal=200)
302302
xs, ys = zip(*learner.data.items())
303303

304304
# Make the scale huge to no get a scale doubling
@@ -374,7 +374,7 @@ def f(x):
374374
loss = curvature_loss_function()
375375
assert loss.nth_neighbors == 1
376376
learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
377-
simple(learner, goal=100)
377+
simple(learner, npoints_goal=100)
378378
assert learner.npoints >= 100
379379

380380

@@ -385,7 +385,7 @@ def f(x):
385385
loss = curvature_loss_function()
386386
assert loss.nth_neighbors == 1
387387
learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
388-
simple(learner, goal=100)
388+
simple(learner, npoints_goal=100)
389389
assert learner.npoints >= 100
390390

391391

@@ -398,7 +398,7 @@ def f(x):
398398
return x + a**2 / (a**2 + x**2)
399399

400400
learner = Learner1D(f, bounds=(-1, 1))
401-
simple(learner, 100)
401+
simple(learner, npoints_goal=100)
402402

403403

404404
def test_inf_loss_with_missing_bounds():
@@ -408,6 +408,6 @@ def test_inf_loss_with_missing_bounds():
408408
loss_per_interval=curvature_loss_function(),
409409
)
410410
# must be done in parallel because otherwise the bounds will be evaluated first
411-
BlockingRunner(learner, goal=0.01)
411+
BlockingRunner(learner, loss_goal=0.01)
412412

413413
learner.npoints > 20

adaptive/tests/test_learnernd.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def test_interior_vs_bbox_gives_same_result():
3333
hull = scipy.spatial.ConvexHull(control._bounds_points)
3434
learner = LearnerND(f, bounds=hull)
3535

36-
simple(control, goal=0.1)
37-
simple(learner, goal=0.1)
36+
simple(control, loss_goal=0.1)
37+
simple(learner, loss_goal=0.1)
3838

3939
assert learner.data == control.data
4040

@@ -47,4 +47,4 @@ def test_vector_return_with_a_flat_layer():
4747
h3 = lambda xy: np.array([0 * f(xy), g(xy)]) # noqa: E731
4848
for function in [h1, h2, h3]:
4949
learner = LearnerND(function, bounds=[(-1, 1), (-1, 1)])
50-
simple(learner, goal=0.1)
50+
simple(learner, loss_goal=0.1)

adaptive/tests/test_learners.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def goal():
103103
return get_goal(learner.learner)
104104
return get_goal(learner)
105105

106-
simple(learner, goal())
106+
simple(learner, goal=goal())
107107

108108

109109
# Library of functions and associated learners.

adaptive/tests/test_pickling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_serialization_for(learner_type, learner_kwargs, serializer, f):
9494

9595
learner = learner_type(f, **learner_kwargs)
9696

97-
simple(learner, goal_1)
97+
simple(learner, goal=goal_1)
9898
learner_bytes = serializer.dumps(learner)
9999
loss = learner.loss()
100100
asked = learner.ask(10)
@@ -113,5 +113,5 @@ def test_serialization_for(learner_type, learner_kwargs, serializer, f):
113113
# load again to undo the ask
114114
learner_loaded = serializer.loads(learner_bytes)
115115

116-
simple(learner_loaded, goal_2)
116+
simple(learner_loaded, goal=goal_2)
117117
assert learner_loaded.npoints == 20

adaptive/tests/test_runner.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,18 @@
1919
OPERATING_SYSTEM = platform.system()
2020

2121

22-
def blocking_runner(learner, goal):
23-
BlockingRunner(learner, goal, executor=SequentialExecutor())
22+
def blocking_runner(learner, **kw):
23+
BlockingRunner(learner, executor=SequentialExecutor(), **kw)
2424

2525

26-
def async_runner(learner, goal):
27-
runner = AsyncRunner(learner, goal, executor=SequentialExecutor())
26+
def async_runner(learner, **kw):
27+
runner = AsyncRunner(learner, executor=SequentialExecutor(), **kw)
2828
asyncio.get_event_loop().run_until_complete(runner.task)
2929

3030

3131
runners = [simple, blocking_runner, async_runner]
3232

3333

34-
def trivial_goal(learner):
35-
return learner.npoints > 10
36-
37-
3834
@pytest.mark.parametrize("runner", runners)
3935
def test_simple(runner):
4036
"""Test that the runners actually run."""
@@ -43,7 +39,7 @@ def f(x):
4339
return x
4440

4541
learner = Learner1D(f, (-1, 1))
46-
runner(learner, 10)
42+
runner(learner, npoints_goal=10)
4743
assert len(learner.data) >= 10
4844

4945

@@ -57,15 +53,15 @@ def test_nonconforming_output(runner):
5753
def f(x):
5854
return [0]
5955

60-
runner(Learner2D(f, ((-1, 1), (-1, 1))), trivial_goal)
56+
runner(Learner2D(f, ((-1, 1), (-1, 1))), npoints_goal=10)
6157

6258

6359
def test_aync_def_function():
6460
async def f(x):
6561
return x
6662

6763
learner = Learner1D(f, (-1, 1))
68-
runner = AsyncRunner(learner, trivial_goal)
64+
runner = AsyncRunner(learner, npoints_goal=10)
6965
asyncio.get_event_loop().run_until_complete(runner.task)
7066

7167

@@ -88,15 +84,15 @@ def test_concurrent_futures_executor():
8884

8985
BlockingRunner(
9086
Learner1D(linear, (-1, 1)),
91-
trivial_goal,
87+
npoints_goal=10,
9288
executor=ProcessPoolExecutor(max_workers=1),
9389
)
9490

9591

9692
def test_stop_after_goal():
9793
seconds_to_wait = 0.2 # don't make this too large or the test will take ages
9894
start_time = time.time()
99-
BlockingRunner(Learner1D(linear, (-1, 1)), stop_after(seconds=seconds_to_wait))
95+
BlockingRunner(Learner1D(linear, (-1, 1)), goal=stop_after(seconds=seconds_to_wait))
10096
stop_time = time.time()
10197
assert stop_time - start_time > seconds_to_wait
10298

@@ -119,7 +115,7 @@ def test_ipyparallel_executor():
119115
child.expect("Engines appear to have started successfully", timeout=35)
120116
ipyparallel_executor = Client()
121117
learner = Learner1D(linear, (-1, 1))
122-
BlockingRunner(learner, trivial_goal, executor=ipyparallel_executor)
118+
BlockingRunner(learner, npoints_goal=10, executor=ipyparallel_executor)
123119

124120
assert learner.npoints > 0
125121

@@ -137,20 +133,20 @@ def test_distributed_executor():
137133

138134
learner = Learner1D(linear, (-1, 1))
139135
client = Client(n_workers=1)
140-
BlockingRunner(learner, trivial_goal, executor=client)
136+
BlockingRunner(learner, npoints_goal=10, executor=client)
141137
client.shutdown()
142138
assert learner.npoints > 0
143139

144140

145141
def test_loky_executor(loky_executor):
146142
learner = Learner1D(lambda x: x, (-1, 1))
147143
BlockingRunner(
148-
learner, trivial_goal, executor=loky_executor, shutdown_executor=True
144+
learner, npoints_goal=10, executor=loky_executor, shutdown_executor=True
149145
)
150146
assert learner.npoints > 0
151147

152148

153149
def test_default_executor():
154150
learner = Learner1D(linear, (-1, 1))
155-
runner = AsyncRunner(learner, goal=10)
151+
runner = AsyncRunner(learner, npoints_goal=10)
156152
asyncio.get_event_loop().run_until_complete(runner.task)

adaptive/tests/test_sequence_learner.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ def test_fail_with_sequence_of_unhashable():
1919
# https://github.com/python-adaptive/adaptive/issues/265
2020
seq = [{1: 1}] # unhashable
2121
learner = SequenceLearner(FailOnce(), sequence=seq)
22-
runner = Runner(
23-
learner, goal=SequenceLearner.done, retries=1, executor=SequentialExecutor()
24-
)
22+
runner = Runner(learner, retries=1, executor=SequentialExecutor())
2523
asyncio.get_event_loop().run_until_complete(runner.task)
2624
assert runner.status() == "finished"

0 commit comments

Comments
 (0)