From 55b60da58d7345fb78506438f193f995883ec8d4 Mon Sep 17 00:00:00 2001 From: Johannes Richter Date: Fri, 18 Feb 2022 11:12:07 +0100 Subject: [PATCH 1/5] Close transport when fetching the schema failed Fetching the schema from the transport might fail with an exception. This Commit ensures that the transport is closed in such a case and the client object can be used to open a new session. --- gql/client.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/gql/client.py b/gql/client.py index 91cbcde6..64dd2469 100644 --- a/gql/client.py +++ b/gql/client.py @@ -271,8 +271,12 @@ async def __aenter__(self): self.session = AsyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - await self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + await self.session.fetch_schema() + except Exception: + await self.transport.close() + raise return self.session @@ -293,8 +297,12 @@ def __enter__(self): self.session = SyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + self.session.fetch_schema() + except Exception: + await self.transport.close() + raise return self.session From 34078a8ae1db4e318a7506770d5894638e83f3c1 Mon Sep 17 00:00:00 2001 From: Johannes Richter Date: Mon, 21 Feb 2022 16:13:28 +0100 Subject: [PATCH 2/5] Remove await in sync enter function --- gql/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/client.py b/gql/client.py index 64dd2469..e8ef3b10 100644 --- a/gql/client.py +++ b/gql/client.py @@ -301,7 +301,7 @@ def __enter__(self): if self.fetch_schema_from_transport and not self.schema: self.session.fetch_schema() except Exception: - await self.transport.close() + self.transport.close() raise return self.session From 468193dd384c14a1009841ca310e7e132712ff5b Mon Sep 17 00:00:00 2001 From: Johannes Richter Date: Mon, 21 Feb 2022 16:14:29 +0100 Subject: [PATCH 3/5] Add tests --- tests/test_client.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index c8df40ee..50b3ba30 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -200,3 +200,47 @@ def test_gql(): client = Client(schema=schema) result = client.execute(query) assert result["user"] is None + + +def test_sync_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.requests import RequestsHTTPTransport + transport = RequestsHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + with client as session: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_async_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.aiohttp import AIOHTTPTransport + transport = AIOHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + async with client as session: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None From 78f954b9cd735e632d52ae18c2203ac12c187c0a Mon Sep 17 00:00:00 2001 From: Johannes Richter Date: Mon, 21 Feb 2022 16:22:27 +0100 Subject: [PATCH 4/5] Add comment on why we use broad exception --- gql/client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gql/client.py b/gql/client.py index e8ef3b10..5203d17d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -275,6 +275,9 @@ async def __aenter__(self): if self.fetch_schema_from_transport and not self.schema: await self.session.fetch_schema() except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception await self.transport.close() raise @@ -301,6 +304,9 @@ def __enter__(self): if self.fetch_schema_from_transport and not self.schema: self.session.fetch_schema() except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception self.transport.close() raise From 18fead44653729ab946621703a23b7b199ee7804 Mon Sep 17 00:00:00 2001 From: Johannes Richter Date: Tue, 22 Feb 2022 07:50:05 +0100 Subject: [PATCH 5/5] Respect coding guidelines --- tests/test_client.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 50b3ba30..fecdf43d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -202,6 +202,7 @@ def test_gql(): assert result["user"] is None +@pytest.mark.requests def test_sync_transport_close_on_schema_retrieval_failure(): """ Ensure that the transport session is closed if an error occurs when @@ -209,11 +210,12 @@ def test_sync_transport_close_on_schema_retrieval_failure(): """ from gql.transport.requests import RequestsHTTPTransport + transport = RequestsHTTPTransport(url="http://localhost/") client = Client(transport=transport, fetch_schema_from_transport=True) try: - with client as session: + with client: pass except Exception: # we don't care what exception is thrown, we just want to check if the @@ -232,11 +234,12 @@ async def test_async_transport_close_on_schema_retrieval_failure(): """ from gql.transport.aiohttp import AIOHTTPTransport + transport = AIOHTTPTransport(url="http://localhost/") client = Client(transport=transport, fetch_schema_from_transport=True) try: - async with client as session: + async with client: pass except Exception: # we don't care what exception is thrown, we just want to check if the