14
14
import warnings
15
15
from contextlib import suppress
16
16
from datetime import datetime , timedelta
17
- from typing import Any , Callable
17
+ from typing import Any , Callable , Union
18
18
19
19
import loky
20
20
21
21
from adaptive import BalancingLearner , BaseLearner , IntegratorLearner , SequenceLearner
22
22
from adaptive .notebook_integration import in_ipynb , live_info , live_plot
23
23
24
+ try :
25
+ from typing import TypeAlias
26
+ except ModuleNotFoundError :
27
+ # Python <3.10
28
+ from typing_extensions import TypeAlias
29
+
24
30
try :
25
31
import ipyparallel
26
32
60
66
# and https://github.com/python-adaptive/adaptive/issues/301
61
67
_default_executor = loky .get_reusable_executor
62
68
69
+ GoalTypes : TypeAlias = Union [
70
+ Callable [[BaseLearner ], bool ], int , float , datetime , timedelta , None
71
+ ]
72
+
63
73
64
74
class BaseRunner (metaclass = abc .ABCMeta ):
65
75
r"""Base class for runners that use `concurrent.futures.Executors`.
@@ -120,8 +130,10 @@ class BaseRunner(metaclass=abc.ABCMeta):
120
130
def __init__ (
121
131
self ,
122
132
learner ,
123
- goal ,
124
133
* ,
134
+ goal : GoalTypes = None ,
135
+ loss_goal : float | None = None ,
136
+ npoints_goal : int | None = None ,
125
137
executor = None ,
126
138
ntasks = None ,
127
139
log = False ,
@@ -132,7 +144,7 @@ def __init__(
132
144
):
133
145
134
146
self .executor = _ensure_executor (executor )
135
- self .goal = auto_goal ( goal , learner , allow_running_forever )
147
+ self .goal = _goal ( learner , goal , loss_goal , npoints_goal , allow_running_forever )
136
148
137
149
self ._max_tasks = ntasks
138
150
@@ -376,8 +388,10 @@ class BlockingRunner(BaseRunner):
376
388
def __init__ (
377
389
self ,
378
390
learner ,
379
- goal ,
380
391
* ,
392
+ goal : GoalTypes = None ,
393
+ loss_goal : float | None = None ,
394
+ npoints_goal : int | None = None ,
381
395
executor = None ,
382
396
ntasks = None ,
383
397
log = False ,
@@ -389,7 +403,9 @@ def __init__(
389
403
raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
390
404
super ().__init__ (
391
405
learner ,
392
- goal ,
406
+ goal = goal ,
407
+ loss_goal = loss_goal ,
408
+ npoints_goal = npoints_goal ,
393
409
executor = executor ,
394
410
ntasks = ntasks ,
395
411
log = log ,
@@ -508,8 +524,10 @@ class AsyncRunner(BaseRunner):
508
524
def __init__ (
509
525
self ,
510
526
learner ,
511
- goal = None ,
512
527
* ,
528
+ goal : GoalTypes = None ,
529
+ loss_goal : float | None = None ,
530
+ npoints_goal : int | None = None ,
513
531
executor = None ,
514
532
ntasks = None ,
515
533
log = False ,
@@ -537,7 +555,9 @@ def __init__(
537
555
538
556
super ().__init__ (
539
557
learner ,
540
- goal ,
558
+ goal = goal ,
559
+ loss_goal = loss_goal ,
560
+ npoints_goal = npoints_goal ,
541
561
executor = executor ,
542
562
ntasks = ntasks ,
543
563
log = log ,
@@ -717,7 +737,13 @@ async def _saver():
717
737
Runner = AsyncRunner
718
738
719
739
720
- def simple (learner , goal ):
740
+ def simple (
741
+ learner ,
742
+ * ,
743
+ goal : GoalTypes = None ,
744
+ loss_goal : float | None = None ,
745
+ npoints_goal : int | None = None ,
746
+ ):
721
747
"""Run the learner until the goal is reached.
722
748
723
749
Requests a single point from the learner, evaluates
@@ -736,7 +762,7 @@ def simple(learner, goal):
736
762
The end condition for the calculation. This function must take the
737
763
learner as its sole argument, and return True if we should stop.
738
764
"""
739
- goal = auto_goal ( goal , learner )
765
+ goal = _goal ( learner , goal , loss_goal , npoints_goal , allow_running_forever = False )
740
766
while not goal (learner ):
741
767
xs , _ = learner .ask (1 )
742
768
for x in xs :
@@ -871,14 +897,13 @@ def __call__(self, _):
871
897
if self .start_time is None :
872
898
self .start_time = datetime .now ()
873
899
return datetime .now () - self .start_time > self .dt
874
- elif isinstance (self .dt , datetime ):
900
+ if isinstance (self .dt , datetime ):
875
901
return datetime .now () > self .dt
876
- else :
877
- raise TypeError (f"`dt={ self .dt } ` is not a datetime or timedelta." )
902
+ raise TypeError (f"`dt={ self .dt } ` is not a datetime or timedelta." )
878
903
879
904
880
905
def auto_goal (
881
- goal : Callable [[ BaseLearner ], bool ] | int | float | datetime | timedelta | None ,
906
+ goal : GoalTypes ,
882
907
learner : BaseLearner ,
883
908
allow_running_forever : bool = True ,
884
909
):
@@ -935,12 +960,28 @@ def auto_goal(
935
960
return SequenceLearner .done
936
961
if isinstance (learner , IntegratorLearner ):
937
962
return IntegratorLearner .done
938
- warnings .warn ("Goal is None which means the learners continue forever!" )
939
- if allow_running_forever :
940
- return lambda _ : False
941
- else :
963
+ if not allow_running_forever :
942
964
raise ValueError (
943
965
"Goal is None which means the learners"
944
966
" continue forever and this is not allowed."
945
967
)
968
+ warnings .warn ("Goal is None which means the learners continue forever!" )
969
+ return lambda _ : False
946
970
raise ValueError ("Cannot determine goal from {goal}." )
971
+
972
+
973
+ def _goal (
974
+ learner : BaseLearner ,
975
+ goal : GoalTypes ,
976
+ loss_goal : float | None ,
977
+ npoints_goal : int | None ,
978
+ allow_running_forever : bool ,
979
+ ):
980
+ # goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
981
+ if goal is not None and (loss_goal is not None or npoints_goal is not None ):
982
+ raise ValueError ("Either goal, loss_goal, or npoints_goal can be specified." )
983
+ if loss_goal is not None :
984
+ goal = float (loss_goal )
985
+ if npoints_goal is not None :
986
+ goal = int (npoints_goal )
987
+ return auto_goal (goal , learner , allow_running_forever )
0 commit comments