diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 090463e9..f34a0066 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,3 +1,5 @@ +import asyncio +import functools import io import json import logging @@ -44,6 +46,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -53,6 +56,8 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -65,6 +70,7 @@ def __init__( self.auth: Optional[BasicAuth] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None @@ -100,6 +106,59 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") + @staticmethod + def create_aiohttp_closed_event(session) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost + async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -108,7 +167,12 @@ async def close(self) -> None: when you exit the async context manager. """ if self.session is not None: + closed_event = self.create_aiohttp_closed_event(self.session) await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass self.session = None async def execute( diff --git a/tests/conftest.py b/tests/conftest.py index 6fd9fc44..c0101241 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_transport) -@pytest.fixture -async def aiohttp_server(): +async def aiohttp_server_base(with_ssl=False): """Factory to create a TestServer instance, given an app. aiohttp_server(app, **kwargs) @@ -89,7 +88,13 @@ async def aiohttp_server(): async def go(app, *, port=None, **kwargs): # type: ignore server = AIOHTTPTestServer(app, port=port) - await server.start_server(**kwargs) + + start_server_args = {**kwargs} + if with_ssl: + testcert, ssl_context = get_localhost_ssl_context() + start_server_args["ssl"] = ssl_context + + await server.start_server(**start_server_args) servers.append(server) return server @@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore await servers.pop().close() +@pytest.fixture +async def aiohttp_server(): + async for server in aiohttp_server_base(): + yield server + + +@pytest.fixture +async def ssl_aiohttp_server(): + async for server in aiohttp_server_base(with_ssl=True): + yield server + + # Adding debug logs to websocket tests for name in [ "websockets.legacy.server", @@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) +def get_localhost_ssl_context(): + # This is a copy of certificate from websockets tests folder + # + # Generate TLS certificate with: + # $ openssl req -x509 -config test_localhost.cnf \ + # -days 15340 -newkey rsa:2048 \ + # -out test_localhost.crt -keyout test_localhost.key + # $ cat test_localhost.key test_localhost.crt > test_localhost.pem + # $ rm test_localhost.key test_localhost.crt + testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(testcert) + + return (testcert, ssl_context) + + class WebSocketServer: """Websocket server on localhost on a free port. @@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args = {} if self.with_ssl: - # This is a copy of certificate from websockets tests folder - # - # Generate TLS certificate with: - # $ openssl req -x509 -config test_localhost.cnf \ - # -days 15340 -newkey rsa:2048 \ - # -out test_localhost.crt -keyout test_localhost.key - # $ cat test_localhost.key test_localhost.crt > test_localhost.pem - # $ rm test_localhost.key test_localhost.crt - self.testcert = bytes( - pathlib.Path(__file__).with_name("test_localhost.pem") - ) - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(self.testcert) - + self.testcert, ssl_context = get_localhost_ssl_context() extra_serve_args["ssl"] = ssl_context # Start a server with a random open port diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 50cec3f9..6dbe46ae 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1073,3 +1073,38 @@ async def handler(request): execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + sample_transport = AIOHTTPTransport( + url=url, timeout=10, ssl_close_timeout=ssl_close_timeout + ) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF"