From b4facb07ac2e78696c6dd2c61381b3c021c788ea Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 14 Oct 2020 23:35:01 +0200 Subject: [PATCH 1/3] Fix aiohttp wait for closed connections --- gql/transport/aiohttp.py | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2ae83999..bed4b54c 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,3 +1,5 @@ +import asyncio +import functools from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Union @@ -88,6 +90,59 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") + @staticmethod + def create_aiohttp_closed_event(session) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal transports + + try: + orig_lost(exc) + finally: + transports -= 1 + if transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: + orig_eof_received() + except AttributeError: + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if transports == 0: + all_is_lost.set() + + return all_is_lost + async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -96,7 +151,9 @@ async def close(self) -> None: when you exit the async context manager. """ if self.session is not None: + closed_event = self.create_aiohttp_closed_event(self.session) await self.session.close() + await closed_event.wait() self.session = None async def execute( From 74d6589e6662cb459e27fb9ee06bcfb4df839fd7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 26 Nov 2021 17:31:20 +0100 Subject: [PATCH 2/3] Adding aiohttp test with ssl --- gql/transport/aiohttp.py | 14 +++++------ tests/conftest.py | 54 +++++++++++++++++++++++++++------------- tests/test_aiohttp.py | 32 ++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 31914f89..c0b6263b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -112,23 +112,23 @@ def create_aiohttp_closed_event(session) -> asyncio.Event: An event that will be set once all transports have been properly closed. """ - transports = 0 + ssl_transports = 0 all_is_lost = asyncio.Event() def connection_lost(exc, orig_lost): - nonlocal transports + nonlocal ssl_transports try: orig_lost(exc) finally: - transports -= 1 - if transports == 0: + ssl_transports -= 1 + if ssl_transports == 0: all_is_lost.set() def eof_received(orig_eof_received): try: orig_eof_received() - except AttributeError: + except AttributeError: # pragma: no cover # It may happen that eof_received() is called after # _app_protocol and _transport are set to None. pass @@ -139,7 +139,7 @@ def eof_received(orig_eof_received): if proto is None: continue - transports += 1 + ssl_transports += 1 orig_lost = proto.connection_lost orig_eof_received = proto.eof_received @@ -150,7 +150,7 @@ def eof_received(orig_eof_received): eof_received, orig_eof_received=orig_eof_received ) - if transports == 0: + if ssl_transports == 0: all_is_lost.set() return all_is_lost diff --git a/tests/conftest.py b/tests/conftest.py index 6fd9fc44..c0101241 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_transport) -@pytest.fixture -async def aiohttp_server(): +async def aiohttp_server_base(with_ssl=False): """Factory to create a TestServer instance, given an app. aiohttp_server(app, **kwargs) @@ -89,7 +88,13 @@ async def aiohttp_server(): async def go(app, *, port=None, **kwargs): # type: ignore server = AIOHTTPTestServer(app, port=port) - await server.start_server(**kwargs) + + start_server_args = {**kwargs} + if with_ssl: + testcert, ssl_context = get_localhost_ssl_context() + start_server_args["ssl"] = ssl_context + + await server.start_server(**start_server_args) servers.append(server) return server @@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore await servers.pop().close() +@pytest.fixture +async def aiohttp_server(): + async for server in aiohttp_server_base(): + yield server + + +@pytest.fixture +async def ssl_aiohttp_server(): + async for server in aiohttp_server_base(with_ssl=True): + yield server + + # Adding debug logs to websocket tests for name in [ "websockets.legacy.server", @@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) +def get_localhost_ssl_context(): + # This is a copy of certificate from websockets tests folder + # + # Generate TLS certificate with: + # $ openssl req -x509 -config test_localhost.cnf \ + # -days 15340 -newkey rsa:2048 \ + # -out test_localhost.crt -keyout test_localhost.key + # $ cat test_localhost.key test_localhost.crt > test_localhost.pem + # $ rm test_localhost.key test_localhost.crt + testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(testcert) + + return (testcert, ssl_context) + + class WebSocketServer: """Websocket server on localhost on a free port. @@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args = {} if self.with_ssl: - # This is a copy of certificate from websockets tests folder - # - # Generate TLS certificate with: - # $ openssl req -x509 -config test_localhost.cnf \ - # -days 15340 -newkey rsa:2048 \ - # -out test_localhost.crt -keyout test_localhost.key - # $ cat test_localhost.key test_localhost.crt > test_localhost.pem - # $ rm test_localhost.key test_localhost.crt - self.testcert = bytes( - pathlib.Path(__file__).with_name("test_localhost.pem") - ) - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(self.testcert) - + self.testcert, ssl_context = get_localhost_ssl_context() extra_serve_args["ssl"] = ssl_context # Start a server with a random open port diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 50cec3f9..1c7b0526 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1073,3 +1073,35 @@ async def handler(request): execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" From c55d37b11e5bf9942c9e57d32a288dd6ba523b40 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 26 Nov 2021 17:59:00 +0100 Subject: [PATCH 3/3] Add new ssl_close_timeout argument to AIOHTTPTransport --- gql/transport/aiohttp.py | 9 ++++++++- tests/test_aiohttp.py | 7 +++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index c0b6263b..f34a0066 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -46,6 +46,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -55,6 +56,8 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -67,6 +70,7 @@ def __init__( self.auth: Optional[BasicAuth] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None @@ -165,7 +169,10 @@ async def close(self) -> None: if self.session is not None: closed_event = self.create_aiohttp_closed_event(self.session) await self.session.close() - await closed_event.wait() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass self.session = None async def execute( diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 1c7b0526..6dbe46ae 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1076,7 +1076,8 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server): +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1091,7 +1092,9 @@ async def handler(request): assert str(url).startswith("https://") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + sample_transport = AIOHTTPTransport( + url=url, timeout=10, ssl_close_timeout=ssl_close_timeout + ) async with Client(transport=sample_transport,) as session: