Skip to content

Commit e74a4a0

Browse files
[PR #8163/006fbe03 backport][3.9] Avoid creating a task to do DNS resolution if there is no throttle (#8172)
Co-authored-by: J. Nick Koston <[email protected]> Fixes #123'). -->
1 parent 87e0697 commit e74a4a0

File tree

3 files changed

+47
-14
lines changed

3 files changed

+47
-14
lines changed

CHANGES/8163.bugfix.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Improved the DNS resolution performance on cache hit
2+
-- by :user:`bdraco`.
3+
4+
This is achieved by avoiding an :mod:`asyncio` task creation
5+
in this case.

aiohttp/connector.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,7 @@ def clear_dns_cache(
825825
async def _resolve_host(
826826
self, host: str, port: int, traces: Optional[List["Trace"]] = None
827827
) -> List[Dict[str, Any]]:
828+
"""Resolve host and return list of addresses."""
828829
if is_ip_address(host):
829830
return [
830831
{
@@ -852,8 +853,7 @@ async def _resolve_host(
852853
return res
853854

854855
key = (host, port)
855-
856-
if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
856+
if key in self._cached_hosts and not self._cached_hosts.expired(key):
857857
# get result early, before any await (#4014)
858858
result = self._cached_hosts.next_addrs(key)
859859

@@ -862,6 +862,39 @@ async def _resolve_host(
862862
await trace.send_dns_cache_hit(host)
863863
return result
864864

865+
#
866+
# If multiple connectors are resolving the same host, we wait
867+
# for the first one to resolve and then use the result for all of them.
868+
# We use a throttle event to ensure that we only resolve the host once
869+
# and then use the result for all the waiters.
870+
#
871+
# In this case we need to create a task to ensure that we can shield
872+
# the task from cancellation as cancelling this lookup should not cancel
873+
# the underlying lookup or else the cancel event will get broadcast to
874+
# all the waiters across all connections.
875+
#
876+
resolved_host_task = asyncio.create_task(
877+
self._resolve_host_with_throttle(key, host, port, traces)
878+
)
879+
try:
880+
return await asyncio.shield(resolved_host_task)
881+
except asyncio.CancelledError:
882+
883+
def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
884+
with suppress(Exception, asyncio.CancelledError):
885+
fut.result()
886+
887+
resolved_host_task.add_done_callback(drop_exception)
888+
raise
889+
890+
async def _resolve_host_with_throttle(
891+
self,
892+
key: Tuple[str, int],
893+
host: str,
894+
port: int,
895+
traces: Optional[List["Trace"]],
896+
) -> List[Dict[str, Any]]:
897+
"""Resolve host with a dns events throttle."""
865898
if key in self._throttle_dns_events:
866899
# get event early, before any await (#4014)
867900
event = self._throttle_dns_events[key]
@@ -1163,22 +1196,11 @@ async def _create_direct_connection(
11631196
host = host.rstrip(".") + "."
11641197
port = req.port
11651198
assert port is not None
1166-
host_resolved = asyncio.ensure_future(
1167-
self._resolve_host(host, port, traces=traces), loop=self._loop
1168-
)
11691199
try:
11701200
# Cancelling this lookup should not cancel the underlying lookup
11711201
# or else the cancel event will get broadcast to all the waiters
11721202
# across all connections.
1173-
hosts = await asyncio.shield(host_resolved)
1174-
except asyncio.CancelledError:
1175-
1176-
def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
1177-
with suppress(Exception, asyncio.CancelledError):
1178-
fut.result()
1179-
1180-
host_resolved.add_done_callback(drop_exception)
1181-
raise
1203+
hosts = await self._resolve_host(host, port, traces=traces)
11821204
except OSError as exc:
11831205
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
11841206
raise

tests/test_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
767767
loop.create_task(conn._resolve_host("localhost", 8080))
768768
loop.create_task(conn._resolve_host("localhost", 8080))
769769
await asyncio.sleep(0)
770+
await asyncio.sleep(0)
770771
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)
771772

772773

@@ -778,6 +779,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
778779
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
779780
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
780781
await asyncio.sleep(0)
782+
await asyncio.sleep(0)
783+
await asyncio.sleep(0)
784+
await asyncio.sleep(0)
781785
assert r1.exception() == e
782786
assert r2.exception() == e
783787

@@ -792,6 +796,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
792796
loop.create_task(conn._resolve_host("localhost", 8080))
793797
f = loop.create_task(conn._resolve_host("localhost", 8080))
794798

799+
await asyncio.sleep(0)
795800
await asyncio.sleep(0)
796801
await conn.close()
797802

@@ -956,6 +961,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
956961
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
957962
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
958963
await asyncio.sleep(0)
964+
await asyncio.sleep(0)
959965
on_dns_cache_hit.assert_called_once_with(
960966
session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost")
961967
)

0 commit comments

Comments
 (0)