9
9
import logging
10
10
import warnings
11
11
from types import TracebackType
12
- from typing import Any , Callable , Generator , List , Optional , Sequence , Tuple , Type , cast
12
+ from typing import (
13
+ Any ,
14
+ AsyncIterator ,
15
+ Callable ,
16
+ Generator ,
17
+ List ,
18
+ Optional ,
19
+ Sequence ,
20
+ Tuple ,
21
+ Type ,
22
+ cast ,
23
+ )
13
24
14
25
from ..datastructures import Headers , HeadersLike
15
26
from ..exceptions import (
27
+ ConnectionClosed ,
16
28
InvalidHandshake ,
17
29
InvalidHeader ,
18
30
InvalidMessage ,
@@ -413,12 +425,22 @@ class Connect:
413
425
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
414
426
can then be used to send and receive messages.
415
427
416
- :func:`connect` can also be used as a asynchronous context manager::
428
+ :func:`connect` can be used as a asynchronous context manager::
417
429
418
430
async with connect(...) as websocket:
419
431
...
420
432
421
- In that case, the connection is closed when exiting the context.
433
+ The connection is closed automatically when exiting the context.
434
+
435
+ :func:`connect` can be used as an infinite asynchronous iterator to
436
+ reconnect automatically on errors::
437
+
438
+ async for websocket in connect(...):
439
+ ...
440
+
441
+ As above, connections are closed automatically. Connection attempts are
442
+ delayed with exponential backoff, starting at two seconds and increasing
443
+ up to one minute.
422
444
423
445
:func:`connect` is a wrapper around the event loop's
424
446
:meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
@@ -577,6 +599,10 @@ def __init__(
577
599
)
578
600
579
601
self .open_timeout = open_timeout
602
+ if logger is None :
603
+ logger = logging .getLogger ("websockets.client" )
604
+ self .logger = logger
605
+
580
606
# This is a coroutine function.
581
607
self ._create_connection = create_connection
582
608
self ._wsuri = wsuri
@@ -615,7 +641,45 @@ def handle_redirect(self, uri: str) -> None:
615
641
# Set the new WebSocket URI. This suffices for same-origin redirects.
616
642
self ._wsuri = new_wsuri
617
643
618
- # async with connect(...)
644
+ BACKOFF_MIN = 2
645
+ BACKOFF_MAX = 60
646
+ BACKOFF_FACTOR = 1.5
647
+
648
+ # async for ... in connect(...):
649
+
650
+ async def __aiter__ (self ) -> AsyncIterator [WebSocketClientProtocol ]:
651
+ backoff_delay = self .BACKOFF_MIN
652
+ while True :
653
+ try :
654
+ protocol = await self
655
+ except Exception :
656
+ # Connection timed out - increase backoff delay
657
+ backoff_delay = min (int (1.5 * backoff_delay ), self .BACKOFF_MAX )
658
+ self .logger .error (
659
+ "! connect failed; retrying in %d seconds" , backoff_delay
660
+ )
661
+ await asyncio .sleep (backoff_delay )
662
+ continue
663
+ else :
664
+ # Connection succeeded - reset backoff delay
665
+ backoff_delay = self .BACKOFF_MIN
666
+
667
+ try :
668
+ yield protocol
669
+ except GeneratorExit :
670
+ raise
671
+ except ConnectionClosed :
672
+ self .logger .debug ("! connection closed; reconnecting" , exc_info = True )
673
+ # Remove this branch when dropping support for Python < 3.8
674
+ # because CancelledError no longer inherits Exception.
675
+ except asyncio .CancelledError :
676
+ raise
677
+ except Exception :
678
+ self .logger .warning ("! an error occurred; reconnecting" , exc_info = True )
679
+ finally :
680
+ await protocol .close ()
681
+
682
+ # async with connect(...) as ...:
619
683
620
684
async def __aenter__ (self ) -> WebSocketClientProtocol :
621
685
return await self
@@ -628,7 +692,7 @@ async def __aexit__(
628
692
) -> None :
629
693
await self .protocol .close ()
630
694
631
- # await connect(...)
695
+ # ... = await connect(...)
632
696
633
697
def __await__ (self ) -> Generator [Any , None , WebSocketClientProtocol ]:
634
698
# Create a suitable iterator by calling __await__ on a coroutine.
@@ -665,7 +729,7 @@ async def __await_impl__(self) -> WebSocketClientProtocol:
665
729
else :
666
730
raise SecurityError ("too many redirects" )
667
731
668
- # yield from connect(...)
732
+ # ... = yield from connect(...)
669
733
670
734
__iter__ = __await__
671
735
0 commit comments