72
72
# and https://github.com/python-adaptive/adaptive/issues/301
73
73
_default_executor = loky .get_reusable_executor
74
74
75
- GoalTypes : TypeAlias = Union [
75
+ _GoalTypes : TypeAlias = Union [
76
76
Callable [[BaseLearner ], bool ], int , float , datetime , timedelta , None
77
77
]
78
78
@@ -86,15 +86,22 @@ class BaseRunner(metaclass=abc.ABCMeta):
86
86
goal : callable, optional
87
87
The end condition for the calculation. This function must take
88
88
the learner as its sole argument, and return True when we should
89
- stop requesting more points. (Advanced use) Instead of providing a
90
- function, see `auto_goal` for other types that are accepted here.
89
+ stop requesting more points.
91
90
loss_goal : float, optional
92
91
Convenience argument, use instead of ``goal``. The end condition for the
93
92
calculation. Stop when the loss is smaller than this value.
94
93
npoints_goal : int, optional
95
94
Convenience argument, use instead of ``goal``. The end condition for the
96
95
calculation. Stop when the number of points is larger or
97
96
equal than this value.
97
+ datetime_goal : datetime, optional
98
+ Convenience argument, use instead of ``goal``. The end condition for the
99
+ calculation. Stop when the current time is larger or equal than this
100
+ value.
101
+ timedelta_goal : timedelta, optional
102
+ Convenience argument, use instead of ``goal``. The end condition for the
103
+ calculation. Stop when the current time is larger or equal than
104
+ ``start_time + timedelta_goal``.
98
105
executor : `concurrent.futures.Executor`, `distributed.Client`,\
99
106
`mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
100
107
`loky.get_reusable_executor`, optional
@@ -144,10 +151,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
144
151
def __init__ (
145
152
self ,
146
153
learner ,
147
- goal : GoalTypes = None ,
154
+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
148
155
* ,
149
156
loss_goal : float | None = None ,
150
157
npoints_goal : int | None = None ,
158
+ datetime_goal : datetime .datetime | None = None ,
159
+ timedelta_goal : datetime .timedelta | None = None ,
151
160
executor = None ,
152
161
ntasks = None ,
153
162
log = False ,
@@ -158,7 +167,15 @@ def __init__(
158
167
):
159
168
160
169
self .executor = _ensure_executor (executor )
161
- self .goal = _goal (learner , goal , loss_goal , npoints_goal , allow_running_forever )
170
+ self .goal = _goal (
171
+ learner ,
172
+ goal ,
173
+ loss_goal ,
174
+ npoints_goal ,
175
+ datetime_goal ,
176
+ timedelta_goal ,
177
+ allow_running_forever ,
178
+ )
162
179
163
180
self ._max_tasks = ntasks
164
181
@@ -348,8 +365,7 @@ class BlockingRunner(BaseRunner):
348
365
goal : callable
349
366
The end condition for the calculation. This function must take
350
367
the learner as its sole argument, and return True when we should
351
- stop requesting more points. (Advanced use) Instead of providing a
352
- function, see `auto_goal` for other types that are accepted here.
368
+ stop requesting more points.
353
369
loss_goal : float
354
370
Convenience argument, use instead of ``goal``. The end condition for the
355
371
calculation. Stop when the loss is smaller than this value.
@@ -410,10 +426,12 @@ class BlockingRunner(BaseRunner):
410
426
def __init__ (
411
427
self ,
412
428
learner ,
413
- goal : GoalTypes = None ,
429
+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
414
430
* ,
415
431
loss_goal : float | None = None ,
416
432
npoints_goal : int | None = None ,
433
+ datetime_goal : datetime .datetime | None = None ,
434
+ timedelta_goal : datetime .timedelta | None = None ,
417
435
executor = None ,
418
436
ntasks = None ,
419
437
log = False ,
@@ -481,8 +499,7 @@ class AsyncRunner(BaseRunner):
481
499
goal : callable, optional
482
500
The end condition for the calculation. This function must take
483
501
the learner as its sole argument, and return True when we should
484
- stop requesting more points. (Advanced use) Instead of providing a
485
- function, see `auto_goal` for other types that are accepted here.
502
+ stop requesting more points.
486
503
If not provided, the runner will run forever (or stop when no more
487
504
points can be added), or until ``self.task.cancel()`` is called.
488
505
loss_goal : float, optional
@@ -492,6 +509,14 @@ class AsyncRunner(BaseRunner):
492
509
Convenience argument, use instead of ``goal``. The end condition for the
493
510
calculation. Stop when the number of points is larger or
494
511
equal than this value.
512
+ datetime_goal : datetime, optional
513
+ Convenience argument, use instead of ``goal``. The end condition for the
514
+ calculation. Stop when the current time is larger or equal than this
515
+ value.
516
+ timedelta_goal : timedelta, optional
517
+ Convenience argument, use instead of ``goal``. The end condition for the
518
+ calculation. Stop when the current time is larger or equal than
519
+ ``start_time + timedelta_goal``.
495
520
executor : `concurrent.futures.Executor`, `distributed.Client`,\
496
521
`mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
497
522
`loky.get_reusable_executor`, optional
@@ -555,10 +580,12 @@ class AsyncRunner(BaseRunner):
555
580
def __init__ (
556
581
self ,
557
582
learner ,
558
- goal : GoalTypes = None ,
583
+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
559
584
* ,
560
585
loss_goal : float | None = None ,
561
586
npoints_goal : int | None = None ,
587
+ datetime_goal : datetime .datetime | None = None ,
588
+ timedelta_goal : datetime .timedelta | None = None ,
562
589
executor = None ,
563
590
ntasks = None ,
564
591
log = False ,
@@ -770,10 +797,12 @@ async def _saver():
770
797
771
798
def simple (
772
799
learner ,
773
- goal : GoalTypes = None ,
800
+ goal : Callable [[ BaseLearner ], bool ] | None = None ,
774
801
* ,
775
802
loss_goal : float | None = None ,
776
803
npoints_goal : int | None = None ,
804
+ datetime_goal : datetime .datetime | None = None ,
805
+ timedelta_goal : datetime .timedelta | None = None ,
777
806
):
778
807
"""Run the learner until the goal is reached.
779
808
@@ -800,8 +829,24 @@ def simple(
800
829
Convenience argument, use instead of ``goal``. The end condition for the
801
830
calculation. Stop when the number of points is larger or
802
831
equal than this value.
832
+ datetime_goal : datetime, optional
833
+ Convenience argument, use instead of ``goal``. The end condition for the
834
+ calculation. Stop when the current time is larger or equal than this
835
+ value.
836
+ timedelta_goal : timedelta, optional
837
+ Convenience argument, use instead of ``goal``. The end condition for the
838
+ calculation. Stop when the current time is larger or equal than
839
+ ``start_time + timedelta_goal``.
803
840
"""
804
- goal = _goal (learner , goal , loss_goal , npoints_goal , allow_running_forever = False )
841
+ goal = _goal (
842
+ learner ,
843
+ goal ,
844
+ loss_goal ,
845
+ npoints_goal ,
846
+ datetime_goal ,
847
+ timedelta_goal ,
848
+ allow_running_forever = False ,
849
+ )
805
850
while not goal (learner ):
806
851
xs , _ = learner .ask (1 )
807
852
for x in xs :
@@ -942,8 +987,12 @@ def __call__(self, _):
942
987
943
988
944
989
def auto_goal (
945
- goal : GoalTypes ,
946
- learner : BaseLearner ,
990
+ * ,
991
+ loss : float | None = None ,
992
+ npoints : int | None = None ,
993
+ datetime : datetime | None = None ,
994
+ timedelta : timedelta | None = None ,
995
+ learner : BaseLearner | None = None ,
947
996
allow_running_forever : bool = True ,
948
997
) -> Callable [[BaseLearner ], bool ]:
949
998
"""Extract a goal from the learners.
@@ -954,7 +1003,6 @@ def auto_goal(
954
1003
The goal to extract. Can be a callable, an integer, a float, a datetime,
955
1004
a timedelta or None.
956
1005
If the type of `goal` is:
957
-
958
1006
* ``callable``, it is returned as is.
959
1007
* ``int``, the goal is reached after that many points have been added.
960
1008
* ``float``, the goal is reached when the learner has reached a loss
@@ -980,23 +1028,36 @@ def auto_goal(
980
1028
-------
981
1029
Callable[[adaptive.BaseLearner], bool]
982
1030
"""
983
- if callable (goal ):
984
- return goal
985
- if isinstance (goal , float ):
986
- return lambda learner : learner .loss () <= goal
1031
+ kw = dict (
1032
+ loss = loss ,
1033
+ npoints = npoints ,
1034
+ datetime = datetime ,
1035
+ timedelta = timedelta ,
1036
+ allow_running_forever = allow_running_forever ,
1037
+ )
1038
+ opts = (loss , npoints , datetime , timedelta ) # all are mutually exclusive
1039
+ if sum (v is not None for v in opts ) > 1 :
1040
+ raise ValueError (
1041
+ "Only one of loss, npoints, datetime, timedelta can be specified."
1042
+ )
1043
+
1044
+ if loss is not None :
1045
+ return lambda learner : learner .loss () <= loss
987
1046
if isinstance (learner , BalancingLearner ):
988
1047
# Note that the float loss goal is more efficiently implemented in the
989
1048
# BalancingLearner itself. That is why the previous if statement is
990
1049
# above this one.
991
- goals = [auto_goal (goal , l , allow_running_forever ) for l in learner .learners ]
1050
+ goals = [auto_goal (learner = l , ** kw ) for l in learner .learners ]
992
1051
return lambda learner : all (goal (l ) for l , goal in zip (learner .learners , goals ))
993
- if isinstance (goal , int ):
994
- return lambda learner : learner .npoints >= goal
995
- if isinstance (goal , (timedelta , datetime )):
996
- return _TimeGoal (goal )
1052
+ if npoints is not None :
1053
+ return lambda learner : learner .npoints >= npoints
1054
+ if datetime is not None :
1055
+ return _TimeGoal (datetime )
1056
+ if timedelta is not None :
1057
+ return _TimeGoal (timedelta )
997
1058
if isinstance (learner , DataSaver ):
998
- return auto_goal (goal , learner .learner , allow_running_forever )
999
- if goal is None :
1059
+ return auto_goal (** kw , learner = learner .learner )
1060
+ if all ( v is None for v in opts ) :
1000
1061
if isinstance (learner , SequenceLearner ):
1001
1062
return SequenceLearner .done
1002
1063
if isinstance (learner , IntegratorLearner ):
@@ -1012,17 +1073,32 @@ def auto_goal(
1012
1073
1013
1074
1014
1075
def _goal (
1015
- learner : BaseLearner ,
1016
- goal : GoalTypes ,
1076
+ learner : BaseLearner | None ,
1077
+ goal : Callable [[ BaseLearner ], bool ] | None ,
1017
1078
loss_goal : float | None ,
1018
1079
npoints_goal : int | None ,
1080
+ datetime_goal : datetime | None ,
1081
+ timedelta_goal : timedelta | None ,
1019
1082
allow_running_forever : bool ,
1020
1083
):
1021
- # goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
1022
- if goal is not None and (loss_goal is not None or npoints_goal is not None ):
1023
- raise ValueError ("Either goal, loss_goal, or npoints_goal can be specified." )
1024
- if loss_goal is not None :
1025
- goal = float (loss_goal )
1026
- if npoints_goal is not None :
1027
- goal = int (npoints_goal )
1028
- return auto_goal (goal , learner , allow_running_forever )
1084
+ if callable (goal ):
1085
+ return goal
1086
+
1087
+ if goal is not None and (
1088
+ loss_goal is not None
1089
+ or npoints_goal is not None
1090
+ or datetime_goal is not None
1091
+ or timedelta_goal is not None
1092
+ ):
1093
+ raise ValueError (
1094
+ "Either goal, loss_goal, npoints_goal, datetime_goal or"
1095
+ " timedelta_goal can be specified, not multiple."
1096
+ )
1097
+ return auto_goal (
1098
+ learner = learner ,
1099
+ loss = loss_goal ,
1100
+ npoints = npoints_goal ,
1101
+ datetime = datetime_goal ,
1102
+ timedelta = timedelta_goal ,
1103
+ allow_running_forever = allow_running_forever ,
1104
+ )
0 commit comments