15
15
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
16
# See the License for the specific language governing permissions and
17
17
# limitations under the License.
18
-
19
-
18
+ import dataclasses
20
19
import functools
21
20
import sys
22
21
import threading
25
24
import warnings
26
25
from abc import ABC , abstractmethod
27
26
from concurrent import futures
28
- from inspect import iscoroutinefunction
27
+
28
+ from . import _utils
29
29
30
30
# Import all built-in retry strategies for easier usage.
31
31
from .retry import retry_base # noqa
50
50
# Import all built-in stop strategies for easier usage.
51
51
from .stop import stop_after_attempt # noqa
52
52
from .stop import stop_after_delay # noqa
53
+ from .stop import stop_before_delay # noqa
53
54
from .stop import stop_all # noqa
54
55
from .stop import stop_any # noqa
55
56
from .stop import stop_never # noqa
89
90
if t .TYPE_CHECKING :
90
91
import types
91
92
93
+ from . import asyncio as tasyncio
92
94
from .retry import RetryBaseT
93
95
from .stop import StopBaseT
94
96
from .wait import WaitBaseT
98
100
WrappedFn = t .TypeVar ("WrappedFn" , bound = t .Callable [..., t .Any ])
99
101
100
102
103
+ dataclass_kwargs = {}
104
+ if sys .version_info >= (3 , 10 ):
105
+ dataclass_kwargs .update ({"slots" : True })
106
+
107
+
108
+ @dataclasses .dataclass (** dataclass_kwargs )
109
+ class IterState :
110
+ actions : t .List [t .Callable [["RetryCallState" ], t .Any ]] = dataclasses .field (
111
+ default_factory = list
112
+ )
113
+ retry_run_result : bool = False
114
+ delay_since_first_attempt : int = 0
115
+ stop_run_result : bool = False
116
+ is_explicit_retry : bool = False
117
+
118
+ def reset (self ) -> None :
119
+ self .actions = []
120
+ self .retry_run_result = False
121
+ self .delay_since_first_attempt = 0
122
+ self .stop_run_result = False
123
+ self .is_explicit_retry = False
124
+
125
+
101
126
class TryAgain (Exception ):
102
127
"""Always retry the executed function when raised."""
103
128
@@ -126,7 +151,9 @@ class BaseAction:
126
151
NAME : t .Optional [str ] = None
127
152
128
153
def __repr__ (self ) -> str :
129
- state_str = ", " .join (f"{ field } ={ getattr (self , field )!r} " for field in self .REPR_FIELDS )
154
+ state_str = ", " .join (
155
+ f"{ field } ={ getattr (self , field )!r} " for field in self .REPR_FIELDS
156
+ )
130
157
return f"{ self .__class__ .__name__ } ({ state_str } )"
131
158
132
159
def __str__ (self ) -> str :
@@ -222,10 +249,14 @@ def copy(
222
249
retry : t .Union [retry_base , object ] = _unset ,
223
250
before : t .Union [t .Callable [["RetryCallState" ], None ], object ] = _unset ,
224
251
after : t .Union [t .Callable [["RetryCallState" ], None ], object ] = _unset ,
225
- before_sleep : t .Union [t .Optional [t .Callable [["RetryCallState" ], None ]], object ] = _unset ,
252
+ before_sleep : t .Union [
253
+ t .Optional [t .Callable [["RetryCallState" ], None ]], object
254
+ ] = _unset ,
226
255
reraise : t .Union [bool , object ] = _unset ,
227
256
retry_error_cls : t .Union [t .Type [RetryError ], object ] = _unset ,
228
- retry_error_callback : t .Union [t .Optional [t .Callable [["RetryCallState" ], t .Any ]], object ] = _unset ,
257
+ retry_error_callback : t .Union [
258
+ t .Optional [t .Callable [["RetryCallState" ], t .Any ]], object
259
+ ] = _unset ,
229
260
) -> "BaseRetrying" :
230
261
"""Copy this object with some parameters changed if needed."""
231
262
return self .__class__ (
@@ -238,7 +269,9 @@ def copy(
238
269
before_sleep = _first_set (before_sleep , self .before_sleep ),
239
270
reraise = _first_set (reraise , self .reraise ),
240
271
retry_error_cls = _first_set (retry_error_cls , self .retry_error_cls ),
241
- retry_error_callback = _first_set (retry_error_callback , self .retry_error_callback ),
272
+ retry_error_callback = _first_set (
273
+ retry_error_callback , self .retry_error_callback
274
+ ),
242
275
)
243
276
244
277
def __repr__ (self ) -> str :
@@ -280,21 +313,37 @@ def statistics(self) -> t.Dict[str, t.Any]:
280
313
self ._local .statistics = t .cast (t .Dict [str , t .Any ], {})
281
314
return self ._local .statistics
282
315
316
+ @property
317
+ def iter_state (self ) -> IterState :
318
+ try :
319
+ return self ._local .iter_state # type: ignore[no-any-return]
320
+ except AttributeError :
321
+ self ._local .iter_state = IterState ()
322
+ return self ._local .iter_state
323
+
283
324
def wraps (self , f : WrappedFn ) -> WrappedFn :
284
325
"""Wrap a function for retrying.
285
326
286
327
:param f: A function to wraps for retrying.
287
328
"""
288
329
289
- @functools .wraps (f )
330
+ @functools .wraps (
331
+ f , functools .WRAPPER_ASSIGNMENTS + ("__defaults__" , "__kwdefaults__" )
332
+ )
290
333
def wrapped_f (* args : t .Any , ** kw : t .Any ) -> t .Any :
291
- return self (f , * args , ** kw )
334
+ # Always create a copy to prevent overwriting the local contexts when
335
+ # calling the same wrapped functions multiple times in the same stack
336
+ copy = self .copy ()
337
+ wrapped_f .statistics = copy .statistics # type: ignore[attr-defined]
338
+ return copy (f , * args , ** kw )
292
339
293
340
def retry_with (* args : t .Any , ** kwargs : t .Any ) -> WrappedFn :
294
341
return self .copy (* args , ** kwargs ).wraps (f )
295
342
343
+ # Preserve attributes
296
344
wrapped_f .retry = self # type: ignore[attr-defined]
297
345
wrapped_f .retry_with = retry_with # type: ignore[attr-defined]
346
+ wrapped_f .statistics = {} # type: ignore[attr-defined]
298
347
299
348
return wrapped_f # type: ignore[return-value]
300
349
@@ -304,42 +353,89 @@ def begin(self) -> None:
304
353
self .statistics ["attempt_number" ] = 1
305
354
self .statistics ["idle_for" ] = 0
306
355
356
+ def _add_action_func (self , fn : t .Callable [..., t .Any ]) -> None :
357
+ self .iter_state .actions .append (fn )
358
+
359
+ def _run_retry (self , retry_state : "RetryCallState" ) -> None :
360
+ self .iter_state .retry_run_result = self .retry (retry_state )
361
+
362
+ def _run_wait (self , retry_state : "RetryCallState" ) -> None :
363
+ if self .wait :
364
+ sleep = self .wait (retry_state )
365
+ else :
366
+ sleep = 0.0
367
+
368
+ retry_state .upcoming_sleep = sleep
369
+
370
+ def _run_stop (self , retry_state : "RetryCallState" ) -> None :
371
+ self .statistics ["delay_since_first_attempt" ] = retry_state .seconds_since_start
372
+ self .iter_state .stop_run_result = self .stop (retry_state )
373
+
307
374
def iter (self , retry_state : "RetryCallState" ) -> t .Union [DoAttempt , DoSleep , t .Any ]: # noqa
375
+ self ._begin_iter (retry_state )
376
+ result = None
377
+ for action in self .iter_state .actions :
378
+ result = action (retry_state )
379
+ return result
380
+
381
+ def _begin_iter (self , retry_state : "RetryCallState" ) -> None : # noqa
382
+ self .iter_state .reset ()
383
+
308
384
fut = retry_state .outcome
309
385
if fut is None :
310
386
if self .before is not None :
311
- self .before (retry_state )
312
- return DoAttempt ()
387
+ self ._add_action_func (self .before )
388
+ self ._add_action_func (lambda rs : DoAttempt ())
389
+ return
313
390
314
- is_explicit_retry = fut .failed and isinstance (fut .exception (), TryAgain )
315
- if not (is_explicit_retry or self .retry (retry_state )):
316
- return fut .result ()
391
+ self .iter_state .is_explicit_retry = fut .failed and isinstance (
392
+ fut .exception (), TryAgain
393
+ )
394
+ if not self .iter_state .is_explicit_retry :
395
+ self ._add_action_func (self ._run_retry )
396
+ self ._add_action_func (self ._post_retry_check_actions )
397
+
398
+ def _post_retry_check_actions (self , retry_state : "RetryCallState" ) -> None :
399
+ if not (self .iter_state .is_explicit_retry or self .iter_state .retry_run_result ):
400
+ self ._add_action_func (lambda rs : rs .outcome .result ())
401
+ return
317
402
318
403
if self .after is not None :
319
- self .after ( retry_state )
404
+ self ._add_action_func ( self . after )
320
405
321
- self .statistics ["delay_since_first_attempt" ] = retry_state .seconds_since_start
322
- if self .stop (retry_state ):
406
+ self ._add_action_func (self ._run_wait )
407
+ self ._add_action_func (self ._run_stop )
408
+ self ._add_action_func (self ._post_stop_check_actions )
409
+
410
+ def _post_stop_check_actions (self , retry_state : "RetryCallState" ) -> None :
411
+ if self .iter_state .stop_run_result :
323
412
if self .retry_error_callback :
324
- return self .retry_error_callback (retry_state )
325
- retry_exc = self .retry_error_cls (fut )
326
- if self .reraise :
327
- raise retry_exc .reraise ()
328
- raise retry_exc from fut .exception ()
413
+ self ._add_action_func (self .retry_error_callback )
414
+ return
329
415
330
- if self .wait :
331
- sleep = self .wait (retry_state )
332
- else :
333
- sleep = 0.0
334
- retry_state .next_action = RetryAction (sleep )
335
- retry_state .idle_for += sleep
336
- self .statistics ["idle_for" ] += sleep
337
- self .statistics ["attempt_number" ] += 1
416
+ def exc_check (rs : "RetryCallState" ) -> None :
417
+ fut = t .cast (Future , rs .outcome )
418
+ retry_exc = self .retry_error_cls (fut )
419
+ if self .reraise :
420
+ raise retry_exc .reraise ()
421
+ raise retry_exc from fut .exception ()
422
+
423
+ self ._add_action_func (exc_check )
424
+ return
425
+
426
+ def next_action (rs : "RetryCallState" ) -> None :
427
+ sleep = rs .upcoming_sleep
428
+ rs .next_action = RetryAction (sleep )
429
+ rs .idle_for += sleep
430
+ self .statistics ["idle_for" ] += sleep
431
+ self .statistics ["attempt_number" ] += 1
432
+
433
+ self ._add_action_func (next_action )
338
434
339
435
if self .before_sleep is not None :
340
- self .before_sleep ( retry_state )
436
+ self ._add_action_func ( self . before_sleep )
341
437
342
- return DoSleep (sleep )
438
+ self . _add_action_func ( lambda rs : DoSleep (rs . upcoming_sleep ) )
343
439
344
440
def __iter__ (self ) -> t .Generator [AttemptManager , None , None ]:
345
441
self .begin ()
@@ -393,7 +489,7 @@ def __call__(
393
489
return do # type: ignore[no-any-return]
394
490
395
491
396
- if sys .version_info [ 1 ] >= 9 :
492
+ if sys .version_info >= ( 3 , 9 ) :
397
493
FutureGenericT = futures .Future [t .Any ]
398
494
else :
399
495
FutureGenericT = futures .Future
@@ -412,7 +508,9 @@ def failed(self) -> bool:
412
508
return self .exception () is not None
413
509
414
510
@classmethod
415
- def construct (cls , attempt_number : int , value : t .Any , has_exception : bool ) -> "Future" :
511
+ def construct (
512
+ cls , attempt_number : int , value : t .Any , has_exception : bool
513
+ ) -> "Future" :
416
514
"""Construct a new Future object."""
417
515
fut = cls (attempt_number )
418
516
if has_exception :
@@ -453,6 +551,8 @@ def __init__(
453
551
self .idle_for : float = 0.0
454
552
#: Next action as decided by the retry manager
455
553
self .next_action : t .Optional [RetryAction ] = None
554
+ #: Next sleep time as decided by the retry manager.
555
+ self .upcoming_sleep : float = 0.0
456
556
457
557
@property
458
558
def seconds_since_start (self ) -> t .Optional [float ]:
@@ -473,7 +573,10 @@ def set_result(self, val: t.Any) -> None:
473
573
self .outcome , self .outcome_timestamp = fut , ts
474
574
475
575
def set_exception (
476
- self , exc_info : t .Tuple [t .Type [BaseException ], BaseException , "types.TracebackType| None" ]
576
+ self ,
577
+ exc_info : t .Tuple [
578
+ t .Type [BaseException ], BaseException , "types.TracebackType| None"
579
+ ],
477
580
) -> None :
478
581
ts = time .monotonic ()
479
582
fut = Future (self .attempt_number )
@@ -495,24 +598,30 @@ def __repr__(self) -> str:
495
598
496
599
497
600
@t .overload
498
- def retry (func : WrappedFn ) -> WrappedFn :
499
- ...
601
+ def retry (func : WrappedFn ) -> WrappedFn : ...
500
602
501
603
502
604
@t .overload
503
605
def retry (
504
- sleep : t .Callable [[t .Union [int , float ]], t .Optional [ t .Awaitable [None ]]] = sleep ,
606
+ sleep : t .Callable [[t .Union [int , float ]], t .Union [ None , t .Awaitable [None ]]] = sleep ,
505
607
stop : "StopBaseT" = stop_never ,
506
608
wait : "WaitBaseT" = wait_none (),
507
- retry : "RetryBaseT" = retry_if_exception_type (),
508
- before : t .Callable [["RetryCallState" ], None ] = before_nothing ,
509
- after : t .Callable [["RetryCallState" ], None ] = after_nothing ,
510
- before_sleep : t .Optional [t .Callable [["RetryCallState" ], None ]] = None ,
609
+ retry : "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type (),
610
+ before : t .Callable [
611
+ ["RetryCallState" ], t .Union [None , t .Awaitable [None ]]
612
+ ] = before_nothing ,
613
+ after : t .Callable [
614
+ ["RetryCallState" ], t .Union [None , t .Awaitable [None ]]
615
+ ] = after_nothing ,
616
+ before_sleep : t .Optional [
617
+ t .Callable [["RetryCallState" ], t .Union [None , t .Awaitable [None ]]]
618
+ ] = None ,
511
619
reraise : bool = False ,
512
620
retry_error_cls : t .Type ["RetryError" ] = RetryError ,
513
- retry_error_callback : t .Optional [t .Callable [["RetryCallState" ], t .Any ]] = None ,
514
- ) -> t .Callable [[WrappedFn ], WrappedFn ]:
515
- ...
621
+ retry_error_callback : t .Optional [
622
+ t .Callable [["RetryCallState" ], t .Union [t .Any , t .Awaitable [t .Any ]]]
623
+ ] = None ,
624
+ ) -> t .Callable [[WrappedFn ], WrappedFn ]: ...
516
625
517
626
518
627
def retry (* dargs : t .Any , ** dkw : t .Any ) -> t .Any :
@@ -533,9 +642,13 @@ def wrap(f: WrappedFn) -> WrappedFn:
533
642
f"this will probably hang indefinitely (did you mean retry={ f .__class__ .__name__ } (...)?)"
534
643
)
535
644
r : "BaseRetrying"
536
- if iscoroutinefunction (f ):
645
+ if _utils . is_coroutine_callable (f ):
537
646
r = AsyncRetrying (* dargs , ** dkw )
538
- elif tornado and hasattr (tornado .gen , "is_coroutine_function" ) and tornado .gen .is_coroutine_function (f ):
647
+ elif (
648
+ tornado
649
+ and hasattr (tornado .gen , "is_coroutine_function" )
650
+ and tornado .gen .is_coroutine_function (f )
651
+ ):
539
652
r = TornadoRetrying (* dargs , ** dkw )
540
653
else :
541
654
r = Retrying (* dargs , ** dkw )
@@ -545,7 +658,7 @@ def wrap(f: WrappedFn) -> WrappedFn:
545
658
return wrap
546
659
547
660
548
- from pip ._vendor .tenacity ._asyncio import AsyncRetrying # noqa:E402,I100
661
+ from pip ._vendor .tenacity .asyncio import AsyncRetrying # noqa:E402,I100
549
662
550
663
if tornado :
551
664
from pip ._vendor .tenacity .tornadoweb import TornadoRetrying
@@ -570,6 +683,7 @@ def wrap(f: WrappedFn) -> WrappedFn:
570
683
"sleep_using_event" ,
571
684
"stop_after_attempt" ,
572
685
"stop_after_delay" ,
686
+ "stop_before_delay" ,
573
687
"stop_all" ,
574
688
"stop_any" ,
575
689
"stop_never" ,
0 commit comments