Commit 02f76a3

implement explicitly
1 parent 11e14c4 commit 02f76a3

6 files changed: +135 -51 lines


adaptive/runner.py

+113 -37
@@ -72,7 +72,7 @@
     # and https://github.com/python-adaptive/adaptive/issues/301
     _default_executor = loky.get_reusable_executor
 
-GoalTypes: TypeAlias = Union[
+_GoalTypes: TypeAlias = Union[
     Callable[[BaseLearner], bool], int, float, datetime, timedelta, None
 ]
 
@@ -86,15 +86,22 @@ class BaseRunner(metaclass=abc.ABCMeta):
     goal : callable, optional
         The end condition for the calculation. This function must take
         the learner as its sole argument, and return True when we should
-        stop requesting more points. (Advanced use) Instead of providing a
-        function, see `auto_goal` for other types that are accepted here.
+        stop requesting more points.
     loss_goal : float, optional
         Convenience argument, use instead of ``goal``. The end condition for the
         calculation. Stop when the loss is smaller than this value.
     npoints_goal : int, optional
         Convenience argument, use instead of ``goal``. The end condition for the
         calculation. Stop when the number of points is larger or
         equal than this value.
+    datetime_goal : datetime, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than this
+        value.
+    timedelta_goal : timedelta, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than
+        ``start_time + timedelta_goal``.
     executor : `concurrent.futures.Executor`, `distributed.Client`,\
                `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
                `loky.get_reusable_executor`, optional
@@ -144,10 +151,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
     def __init__(
         self,
         learner,
-        goal: GoalTypes = None,
+        goal: Callable[[BaseLearner], bool] | None = None,
         *,
         loss_goal: float | None = None,
         npoints_goal: int | None = None,
+        datetime_goal: datetime.datetime | None = None,
+        timedelta_goal: datetime.timedelta | None = None,
         executor=None,
         ntasks=None,
         log=False,
@@ -158,7 +167,15 @@ def __init__(
     ):
 
         self.executor = _ensure_executor(executor)
-        self.goal = _goal(learner, goal, loss_goal, npoints_goal, allow_running_forever)
+        self.goal = _goal(
+            learner,
+            goal,
+            loss_goal,
+            npoints_goal,
+            datetime_goal,
+            timedelta_goal,
+            allow_running_forever,
+        )
 
         self._max_tasks = ntasks
 
@@ -348,8 +365,7 @@ class BlockingRunner(BaseRunner):
     goal : callable
         The end condition for the calculation. This function must take
         the learner as its sole argument, and return True when we should
-        stop requesting more points. (Advanced use) Instead of providing a
-        function, see `auto_goal` for other types that are accepted here.
+        stop requesting more points.
     loss_goal : float
         Convenience argument, use instead of ``goal``. The end condition for the
         calculation. Stop when the loss is smaller than this value.
@@ -410,10 +426,12 @@ class BlockingRunner(BaseRunner):
     def __init__(
         self,
         learner,
-        goal: GoalTypes = None,
+        goal: Callable[[BaseLearner], bool] | None = None,
         *,
         loss_goal: float | None = None,
        npoints_goal: int | None = None,
+        datetime_goal: datetime.datetime | None = None,
+        timedelta_goal: datetime.timedelta | None = None,
         executor=None,
         ntasks=None,
         log=False,
@@ -481,8 +499,7 @@ class AsyncRunner(BaseRunner):
     goal : callable, optional
         The end condition for the calculation. This function must take
         the learner as its sole argument, and return True when we should
-        stop requesting more points. (Advanced use) Instead of providing a
-        function, see `auto_goal` for other types that are accepted here.
+        stop requesting more points.
         If not provided, the runner will run forever (or stop when no more
         points can be added), or until ``self.task.cancel()`` is called.
     loss_goal : float, optional
@@ -492,6 +509,14 @@ class AsyncRunner(BaseRunner):
         Convenience argument, use instead of ``goal``. The end condition for the
         calculation. Stop when the number of points is larger or
         equal than this value.
+    datetime_goal : datetime, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than this
+        value.
+    timedelta_goal : timedelta, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than
+        ``start_time + timedelta_goal``.
     executor : `concurrent.futures.Executor`, `distributed.Client`,\
                `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
                `loky.get_reusable_executor`, optional
@@ -555,10 +580,12 @@ class AsyncRunner(BaseRunner):
     def __init__(
         self,
         learner,
-        goal: GoalTypes = None,
+        goal: Callable[[BaseLearner], bool] | None = None,
         *,
         loss_goal: float | None = None,
         npoints_goal: int | None = None,
+        datetime_goal: datetime.datetime | None = None,
+        timedelta_goal: datetime.timedelta | None = None,
         executor=None,
         ntasks=None,
         log=False,
@@ -770,10 +797,12 @@ async def _saver():
 
 
 def simple(
     learner,
-    goal: GoalTypes = None,
+    goal: Callable[[BaseLearner], bool] | None = None,
     *,
     loss_goal: float | None = None,
     npoints_goal: int | None = None,
+    datetime_goal: datetime.datetime | None = None,
+    timedelta_goal: datetime.timedelta | None = None,
 ):
     """Run the learner until the goal is reached.
 
@@ -800,8 +829,24 @@ def simple(
         Convenience argument, use instead of ``goal``. The end condition for the
         calculation. Stop when the number of points is larger or
         equal than this value.
+    datetime_goal : datetime, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than this
+        value.
+    timedelta_goal : timedelta, optional
+        Convenience argument, use instead of ``goal``. The end condition for the
+        calculation. Stop when the current time is larger or equal than
+        ``start_time + timedelta_goal``.
     """
-    goal = _goal(learner, goal, loss_goal, npoints_goal, allow_running_forever=False)
+    goal = _goal(
+        learner,
+        goal,
+        loss_goal,
+        npoints_goal,
+        datetime_goal,
+        timedelta_goal,
+        allow_running_forever=False,
+    )
     while not goal(learner):
         xs, _ = learner.ask(1)
         for x in xs:
@@ -942,8 +987,12 @@ def __call__(self, _):
 
 
 def auto_goal(
-    goal: GoalTypes,
-    learner: BaseLearner,
+    *,
+    loss: float | None = None,
+    npoints: int | None = None,
+    datetime: datetime | None = None,
+    timedelta: timedelta | None = None,
+    learner: BaseLearner | None = None,
     allow_running_forever: bool = True,
 ) -> Callable[[BaseLearner], bool]:
     """Extract a goal from the learners.
@@ -954,7 +1003,6 @@ def auto_goal(
         The goal to extract. Can be a callable, an integer, a float, a datetime,
         a timedelta or None.
         If the type of `goal` is:
-
         * ``callable``, it is returned as is.
         * ``int``, the goal is reached after that many points have been added.
         * ``float``, the goal is reached when the learner has reached a loss
@@ -980,23 +1028,36 @@ def auto_goal(
     -------
     Callable[[adaptive.BaseLearner], bool]
     """
-    if callable(goal):
-        return goal
-    if isinstance(goal, float):
-        return lambda learner: learner.loss() <= goal
+    kw = dict(
+        loss=loss,
+        npoints=npoints,
+        datetime=datetime,
+        timedelta=timedelta,
+        allow_running_forever=allow_running_forever,
+    )
+    opts = (loss, npoints, datetime, timedelta)  # all are mutually exclusive
+    if sum(v is not None for v in opts) > 1:
+        raise ValueError(
+            "Only one of loss, npoints, datetime, timedelta can be specified."
+        )
+
+    if loss is not None:
+        return lambda learner: learner.loss() <= loss
     if isinstance(learner, BalancingLearner):
         # Note that the float loss goal is more efficiently implemented in the
         # BalancingLearner itself. That is why the previous if statement is
         # above this one.
-        goals = [auto_goal(goal, l, allow_running_forever) for l in learner.learners]
+        goals = [auto_goal(learner=l, **kw) for l in learner.learners]
         return lambda learner: all(goal(l) for l, goal in zip(learner.learners, goals))
-    if isinstance(goal, int):
-        return lambda learner: learner.npoints >= goal
-    if isinstance(goal, (timedelta, datetime)):
-        return _TimeGoal(goal)
+    if npoints is not None:
+        return lambda learner: learner.npoints >= npoints
+    if datetime is not None:
+        return _TimeGoal(datetime)
+    if timedelta is not None:
+        return _TimeGoal(timedelta)
     if isinstance(learner, DataSaver):
-        return auto_goal(goal, learner.learner, allow_running_forever)
-    if goal is None:
+        return auto_goal(**kw, learner=learner.learner)
+    if all(v is None for v in opts):
         if isinstance(learner, SequenceLearner):
             return SequenceLearner.done
         if isinstance(learner, IntegratorLearner):
@@ -1012,17 +1073,32 @@ def auto_goal(
 
 
 def _goal(
-    learner: BaseLearner,
-    goal: GoalTypes,
+    learner: BaseLearner | None,
+    goal: Callable[[BaseLearner], bool] | None,
     loss_goal: float | None,
     npoints_goal: int | None,
+    datetime_goal: datetime | None,
+    timedelta_goal: timedelta | None,
     allow_running_forever: bool,
 ):
-    # goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
-    if goal is not None and (loss_goal is not None or npoints_goal is not None):
-        raise ValueError("Either goal, loss_goal, or npoints_goal can be specified.")
-    if loss_goal is not None:
-        goal = float(loss_goal)
-    if npoints_goal is not None:
-        goal = int(npoints_goal)
-    return auto_goal(goal, learner, allow_running_forever)
+    if callable(goal):
+        return goal
+
+    if goal is not None and (
+        loss_goal is not None
+        or npoints_goal is not None
+        or datetime_goal is not None
+        or timedelta_goal is not None
+    ):
+        raise ValueError(
+            "Either goal, loss_goal, npoints_goal, datetime_goal or"
+            " timedelta_goal can be specified, not multiple."
+        )
+    return auto_goal(
+        learner=learner,
+        loss=loss_goal,
+        npoints=npoints_goal,
+        datetime=datetime_goal,
+        timedelta=timedelta_goal,
+        allow_running_forever=allow_running_forever,
+    )
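
For context, here is a minimal usage sketch of the keyword-only goal arguments shown in the diff above; the Learner1D setup and the quadratic function are placeholders, not part of this commit:

    from adaptive import Learner1D
    from adaptive.runner import auto_goal, simple

    learner = Learner1D(lambda x: x**2, bounds=(-1, 1))

    # The convenience arguments are mutually exclusive; pass exactly one.
    simple(learner, npoints_goal=10)   # stop once at least 10 points were added
    # simple(learner, loss_goal=0.01)  # or stop once the loss drops below 0.01

    # auto_goal builds the same kind of goal callable explicitly:
    goal = auto_goal(npoints=10)
    assert goal(learner)               # True, the learner already has >= 10 points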

adaptive/tests/test_balancing_learner.py

+7 -7
@@ -50,15 +50,15 @@ def test_ask_0(strategy):
 
 
 @pytest.mark.parametrize(
-    "strategy, goal",
+    "strategy, goal_type, goal",
     [
-        ("loss", 0.1),
-        ("loss_improvements", 0.1),
-        ("npoints", lambda bl: all(l.npoints > 10 for l in bl.learners)),
-        ("cycle", 0.1),
+        ("loss", "loss_goal", 0.1),
+        ("loss_improvements", "loss_goal", 0.1),
+        ("npoints", "goal", lambda bl: all(l.npoints > 10 for l in bl.learners)),
+        ("cycle", "loss_goal", 0.1),
     ],
 )
-def test_strategies(strategy, goal):
+def test_strategies(strategy, goal_type, goal):
     learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
     learner = BalancingLearner(learners, strategy=strategy)
-    simple(learner, goal=goal)
+    simple(learner, **{goal_type: goal})
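
The **{goal_type: goal} unpacking in the updated test is plain keyword forwarding, so each parametrized case maps onto one of the new convenience arguments; for example:

    simple(learner, **{"loss_goal": 0.1})  # identical to:
    simple(learner, loss_goal=0.1)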

adaptive/tests/test_runner.py

+7 -7
@@ -163,36 +163,36 @@ def test_default_executor():
 
 def test_auto_goal():
     learner = Learner1D(linear, (-1, 1))
-    simple(learner, auto_goal(4, learner))
+    simple(learner, auto_goal(npoints=4))
     assert learner.npoints == 4
 
     learner = Learner1D(linear, (-1, 1))
-    simple(learner, auto_goal(0.5, learner))
+    simple(learner, auto_goal(loss=0.5))
     assert learner.loss() <= 0.5
 
     learner = SequenceLearner(linear, np.linspace(-1, 1))
-    simple(learner, auto_goal(None, learner))
+    simple(learner, auto_goal(learner=learner))
     assert learner.done()
 
     learner = IntegratorLearner(linear, bounds=(0, 1), tol=0.1)
-    simple(learner, auto_goal(None, learner))
+    simple(learner, auto_goal(learner=learner))
     assert learner.done()
 
     learner = Learner1D(linear, (-1, 1))
     learner = DataSaver(learner, lambda x: x)
-    simple(learner, auto_goal(4, learner))
+    simple(learner, auto_goal(npoints=4, learner=learner))
     assert learner.npoints == 4
 
     learner1 = Learner1D(linear, (-1, 1))
     learner2 = Learner1D(linear, (-2, 2))
     balancing_learner = BalancingLearner([learner1, learner2])
-    simple(balancing_learner, auto_goal(4, balancing_learner))
+    simple(balancing_learner, auto_goal(npoints=4, learner=balancing_learner))
     assert learner1.npoints == 4 and learner2.npoints == 4
 
     learner1 = Learner1D(linear, bounds=(0, 1))
     learner1 = DataSaver(learner1, lambda x: x)
     learner2 = Learner1D(linear, bounds=(0, 1))
     learner2 = DataSaver(learner2, lambda x: x)
     balancing_learner = BalancingLearner([learner1, learner2])
-    simple(balancing_learner, auto_goal(10, balancing_learner))
+    simple(balancing_learner, auto_goal(npoints=10, learner=balancing_learner))
     assert learner1.npoints == 10 and learner2.npoints == 10
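
The tests above exercise the loss, npoints, and learner-derived goals; a hypothetical sketch of the new time-based goals, which this commit's tests do not cover, could look like:

    from datetime import timedelta
    from adaptive import Learner1D
    from adaptive.runner import simple

    learner = Learner1D(lambda x: x**2, bounds=(-1, 1))
    simple(learner, timedelta_goal=timedelta(seconds=1))  # stops roughly one second after starting
    print(learner.npoints)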

data/periodic_example.p

477 Bytes
Binary file not shown.

docs/source/_static/logo_docs.webm

529 KB
Binary file not shown.

test.py

+8
@@ -0,0 +1,8 @@
+from adaptive import SequenceLearner, runner
+from toolz import identity
+from time import time
+seq = range(int(1e6))
+t = time()
+learner = SequenceLearner(identity, seq)
+runner.simple(learner, SequenceLearner.done)
+print(time() - t)
