bpo-33530: Implement Happy Eyeballs in asyncio, v2 #7237
Changes from all commits
@@ -16,6 +16,7 @@
import collections
import collections.abc
import concurrent.futures
import functools
import heapq
import itertools
import logging
@@ -40,6 +41,7 @@
from . import futures
from . import protocols
from . import sslproto
from . import staggered
from . import tasks
from . import transports
from .log import logger
@@ -147,6 +149,28 @@ def _ipaddr_info(host, port, family, type, proto):
    return None


def _interleave_addrinfos(addrinfos, first_address_family_count=1):
    """Interleave list of addrinfo tuples by family."""
    # Group addresses by family
    addrinfos_by_family = collections.OrderedDict()
    for addr in addrinfos:
        family = addr[0]
        if family not in addrinfos_by_family:
            addrinfos_by_family[family] = []
        addrinfos_by_family[family].append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered = []
    if first_address_family_count > 1:
        reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
        del addrinfos_lists[0][:first_address_family_count - 1]
    reordered.extend(
        a for a in itertools.chain.from_iterable(
            itertools.zip_longest(*addrinfos_lists)
        ) if a is not None)
    return reordered

Review comment (anchored on the reordered.extend() call): "I removed …"
Reply: "I think it's OK to keep it here now."
Review comment: "I think it makes sense to group all Happy Eyeballs-related functions in the same module (in that case the name …"

def _run_until_complete_cb(fut):
    if not fut.cancelled():
        exc = fut.exception()
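To make the reordering concrete, here is a small self-contained sketch that reproduces the `_interleave_addrinfos()` helper above and feeds it made-up `getaddrinfo()`-style tuples (the RFC 5737/3849 example addresses are placeholders):

```python
import itertools
import socket
from collections import OrderedDict


def _interleave_addrinfos(addrinfos, first_address_family_count=1):
    # Local copy of the helper above, reproduced here for illustration only.
    addrinfos_by_family = OrderedDict()
    for addr in addrinfos:
        family = addr[0]
        if family not in addrinfos_by_family:
            addrinfos_by_family[family] = []
        addrinfos_by_family[family].append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())
    reordered = []
    if first_address_family_count > 1:
        reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
        del addrinfos_lists[0][:first_address_family_count - 1]
    reordered.extend(
        a for a in itertools.chain.from_iterable(
            itertools.zip_longest(*addrinfos_lists)
        ) if a is not None)
    return reordered


# Fake getaddrinfo() results: three IPv6 addresses followed by two IPv4 ones.
v6 = [(socket.AF_INET6, socket.SOCK_STREAM, 6, '', (a, 80, 0, 0))
      for a in ('2001:db8::1', '2001:db8::2', '2001:db8::3')]
v4 = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (a, 80))
      for a in ('192.0.2.1', '192.0.2.2')]

print([ai[4][0] for ai in _interleave_addrinfos(v6 + v4)])
# ['2001:db8::1', '192.0.2.1', '2001:db8::2', '192.0.2.2', '2001:db8::3']
print([ai[4][0] for ai in _interleave_addrinfos(v6 + v4, 2)])
# ['2001:db8::1', '2001:db8::2', '192.0.2.1', '2001:db8::3', '192.0.2.2']
```

With the default count of 1 the families simply alternate; a larger *first_address_family_count* keeps that many addresses of the first family at the front before alternating, which is the value `create_connection()` passes through as *interleave*.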
@@ -844,12 +868,49 @@ def _check_sendfile_params(self, sock, file, offset, count):
                "offset must be a non-negative integer (got {!r})".format(
                    offset))

    async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
        """Create, bind and connect one socket."""
        my_exceptions = []
        exceptions.append(my_exceptions)
        family, type_, proto, _, address = addr_info
        sock = None
        try:
            sock = socket.socket(family=family, type=type_, proto=proto)
            sock.setblocking(False)
            if local_addr_infos is not None:
                for _, _, _, _, laddr in local_addr_infos:
                    try:
                        sock.bind(laddr)
                        break
                    except OSError as exc:
                        msg = (
                            f'error while attempting to bind on '
                            f'address {laddr!r}: '
                            f'{exc.strerror.lower()}'
                        )
                        exc = OSError(exc.errno, msg)
                        my_exceptions.append(exc)
                else:  # all bind attempts failed
                    raise my_exceptions.pop()
            await self.sock_connect(sock, address)
            return sock
        except OSError as exc:
            my_exceptions.append(exc)
            if sock is not None:
                sock.close()
            raise
        except:
            if sock is not None:
                sock.close()
            raise
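Note the error-reporting contract: each call appends its own fresh sub-list to the shared *exceptions* argument, so per-attempt errors (failed binds plus the failed connect) stay grouped even when several attempts run concurrently under `staggered_race()`, and `create_connection()` flattens the nesting only after every attempt has failed. A minimal sketch of that contract, using a hypothetical `fake_connect()` in place of real socket work:

```python
import asyncio


async def fake_connect(exceptions, ok):
    # Same shape as _connect_sock(): every attempt owns one sub-list.
    my_exceptions = []
    exceptions.append(my_exceptions)
    if not ok:
        exc = OSError(111, 'connection refused (simulated)')
        my_exceptions.append(exc)
        raise exc
    return 'pretend-socket'


async def main():
    exceptions = []
    results = await asyncio.gather(
        fake_connect(exceptions, ok=False),
        fake_connect(exceptions, ok=False),
        return_exceptions=True)
    print(results)                                   # two OSError instances
    # What create_connection() does when no attempt succeeded:
    flat = [exc for sub in exceptions for exc in sub]
    print(len(flat))                                 # 2


asyncio.run(main())
```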
    async def create_connection(
            self, protocol_factory, host=None, port=None,
            *, ssl=None, family=0,
            proto=0, flags=0, sock=None,
            local_addr=None, server_hostname=None,
            ssl_handshake_timeout=None):
            ssl_handshake_timeout=None,
            happy_eyeballs_delay=None, interleave=None):
        """Connect to a TCP server.

        Create a streaming transport connection to a given Internet host and
@@ -884,6 +945,10 @@ async def create_connection(
            raise ValueError(
                'ssl_handshake_timeout is only meaningful with ssl')

        if happy_eyeballs_delay is not None and interleave is None:
            # If using happy eyeballs, default to interleaving addresses by family
            interleave = 1

        if host is not None or port is not None:
            if sock is not None:
                raise ValueError(
@@ -902,43 +967,31 @@
                    flags=flags, loop=self)
                if not laddr_infos:
                    raise OSError('getaddrinfo() returned empty list')
            else:
                laddr_infos = None

            if interleave:
                infos = _interleave_addrinfos(infos, interleave)

            exceptions = []
            for family, type, proto, cname, address in infos:
                try:
                    sock = socket.socket(family=family, type=type, proto=proto)
                    sock.setblocking(False)
                    if local_addr is not None:
                        for _, _, _, _, laddr in laddr_infos:
                            try:
                                sock.bind(laddr)
                                break
                            except OSError as exc:
                                msg = (
                                    f'error while attempting to bind on '
                                    f'address {laddr!r}: '
                                    f'{exc.strerror.lower()}'
                                )
                                exc = OSError(exc.errno, msg)
                                exceptions.append(exc)
                        else:
                            sock.close()
                            sock = None
                            continue
                    if self._debug:
                        logger.debug("connect %r to %r", sock, address)
                    await self.sock_connect(sock, address)
                except OSError as exc:
                    if sock is not None:
                        sock.close()
                    exceptions.append(exc)
                except:
                    if sock is not None:
                        sock.close()
                    raise
                else:
                    break
            else:
            if happy_eyeballs_delay is None:
                # not using happy eyeballs
                for addrinfo in infos:
                    try:
                        sock = await self._connect_sock(
                            exceptions, addrinfo, laddr_infos)
                        break
                    except OSError:
                        continue
            else:  # using happy eyeballs
                sock, _, _ = await staggered.staggered_race(
                    (functools.partial(self._connect_sock,
                                       exceptions, addrinfo, laddr_infos)
                     for addrinfo in infos),
                    happy_eyeballs_delay, loop=self)

            if sock is None:
                exceptions = [exc for sub in exceptions for exc in sub]
                if len(exceptions) == 1:
                    raise exceptions[0]
                else:
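From the caller's side, the new keyword arguments would be used roughly as follows (the host, port and 0.25 s delay are illustrative; RFC 8305 recommends a connection-attempt delay of 250 ms):

```python
import asyncio


async def main():
    loop = asyncio.get_running_loop()
    # If the first attempt has neither succeeded nor failed within 0.25 s,
    # the next address (of the other family, thanks to interleaving) is
    # tried in parallel. interleave defaults to 1 whenever
    # happy_eyeballs_delay is given.
    transport, protocol = await loop.create_connection(
        asyncio.Protocol, 'example.com', 80,
        happy_eyeballs_delay=0.25)
    transport.close()


asyncio.run(main())
```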
@@ -0,0 +1,147 @@
"""Support for running coroutines in parallel with staggered start times.""" | ||
|
||
__all__ = 'staggered_race', | ||
|
||
import contextlib | ||
import typing | ||
|
||
from . import events | ||
from . import futures | ||
from . import locks | ||
from . import tasks | ||
|
||
|
||
async def staggered_race(
        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
        delay: typing.Optional[float],
        *,
        loop: events.AbstractEventLoop = None,
) -> typing.Tuple[
    typing.Any,
    typing.Optional[int],
    typing.List[typing.Optional[Exception]]
]:
    """Run coroutines with staggered start times and take the first to finish.

    This function takes an iterable of coroutine functions. The first one is
    started immediately. From then on, whenever the immediately preceding one
    fails (raises an exception), or when *delay* seconds have passed, the next
    coroutine is started. This continues until one of the coroutines completes
    successfully, in which case all others are cancelled, or until all
    coroutines fail.

    The coroutines provided should be well-behaved in the following way:

    * They should only ``return`` if completed successfully.

    * They should always raise an exception if they did not complete
      successfully. In particular, if they handle cancellation, they should
      probably reraise, like this::

          try:
              # do work
          except asyncio.CancelledError:
              # undo partially completed work
              raise

    Args:
        coro_fns: an iterable of coroutine functions, i.e. callables that
            return a coroutine object when called. Use ``functools.partial`` or
            lambdas to pass arguments.

        delay: amount of time, in seconds, between starting coroutines. If
            ``None``, the coroutines will run sequentially.

        loop: the event loop to use.

    Returns:
        tuple *(winner_result, winner_index, exceptions)* where

        - *winner_result*: the result of the winning coroutine, or ``None``
          if no coroutines won.

        - *winner_index*: the index of the winning coroutine in
          ``coro_fns``, or ``None`` if no coroutines won. If the winning
          coroutine may return None on success, *winner_index* can be used
          to definitively determine whether any coroutine won.

        - *exceptions*: list of exceptions returned by the coroutines.
          ``len(exceptions)`` is equal to the number of coroutines actually
          started, and the order is the same as in ``coro_fns``. The winning
          coroutine's entry is ``None``.

    """
    # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
    loop = loop or events.get_running_loop()
    enum_coro_fns = enumerate(coro_fns)
    winner_result = None
    winner_index = None
    exceptions = []
    running_tasks = []

    async def run_one_coro(
            previous_failed: typing.Optional[locks.Event]) -> None:
        # Wait for the previous task to finish, or for delay seconds
        if previous_failed is not None:
            with contextlib.suppress(futures.TimeoutError):
                # Use asyncio.wait_for() instead of asyncio.wait() here, so
                # that if we get cancelled at this point, Event.wait() is also
                # cancelled, otherwise there will be a "Task destroyed but it
                # is pending" warning later.
                await tasks.wait_for(previous_failed.wait(), delay)
        # Get the next coroutine to run
        try:
            this_index, coro_fn = next(enum_coro_fns)
        except StopIteration:
            return
        # Start task that will run the next coroutine
        this_failed = locks.Event()
        next_task = loop.create_task(run_one_coro(this_failed))
        running_tasks.append(next_task)
        assert len(running_tasks) == this_index + 2
        # Prepare place to put this coroutine's exceptions if not won
        exceptions.append(None)
        assert len(exceptions) == this_index + 1

        try:
            result = await coro_fn()
        except Exception as e:
            exceptions[this_index] = e
            this_failed.set()  # Kickstart the next coroutine
        else:
            # Store winner's results
            nonlocal winner_index, winner_result
            assert winner_index is None
            winner_index = this_index
            winner_result = result
            # Cancel all other tasks. We take care to not cancel the current
            # task as well. If we do so, then since there is no `await` after
            # this point and a CancelledError is usually raised at an `await`,
            # we will encounter a curious corner case where the current task
            # will end up as done() == True, cancelled() == False,
            # exception() == asyncio.CancelledError. This behavior is
            # specified in https://bugs.python.org/issue30048
            for i, t in enumerate(running_tasks):
                if i != this_index:
                    t.cancel()
||
first_task = loop.create_task(run_one_coro(None)) | ||
running_tasks.append(first_task) | ||
try: | ||
# Wait for a growing list of tasks to all finish: poor man's version of | ||
# curio's TaskGroup or trio's nursery | ||
done_count = 0 | ||
while done_count != len(running_tasks): | ||
done, _ = await tasks.wait(running_tasks) | ||
done_count = len(done) | ||
# If run_one_coro raises an unhandled exception, it's probably a | ||
# programming error, and I want to see it. | ||
if __debug__: | ||
for d in done: | ||
if d.done() and not d.cancelled() and d.exception(): | ||
raise d.exception() | ||
return winner_result, winner_index, exceptions | ||
finally: | ||
# Make sure no tasks are left running if we leave this function | ||
for t in running_tasks: | ||
t.cancel() |
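Although `asyncio.staggered` is an internal module, here is a sketch of driving `staggered_race()` directly (the coroutine names and timings are made up):

```python
import asyncio
import functools
from asyncio import staggered  # internal module added by this PR


async def attempt(name, fail):
    await asyncio.sleep(0.1)
    if fail:
        raise OSError(f'{name} failed')
    return name


async def main():
    winner, index, errors = await staggered.staggered_race(
        [functools.partial(attempt, 'first', fail=True),
         functools.partial(attempt, 'second', fail=False)],
        delay=0.3)
    # 'first' fails after 0.1 s, which starts 'second' right away instead
    # of waiting out the full 0.3 s delay; 'second' then wins.
    print(winner, index, errors)  # roughly: second 1 [OSError('first failed'), None]


asyncio.run(main())
```

If the first attempt had merely been slow rather than failing, the second one would still have been started once *delay* expired, and whichever finished first would win while the other is cancelled.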
@@ -0,0 +1,3 @@
Implemented Happy Eyeballs in `loop.create_connection()`. Added two new
arguments, *happy_eyeballs_delay* and *interleave*, to specify Happy Eyeballs
behavior.
Review comment: Maybe it's worth naming the parameter `interleave_af` to make it more explicit.