diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index cdc4571b..780027b8 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -206,35 +206,36 @@ async def execute( raise TransportClosed("Transport is not connected") async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: - try: - result = await resp.json() - if log.isEnabledFor(logging.INFO): - result_text = await resp.text() - log.info("<<< %s", result_text) - except Exception: + async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases try: # Raise a ClientResponseError if response status is 400 or higher resp.raise_for_status() - except ClientResponseError as e: - raise TransportServerError(str(e)) from e + raise TransportServerError(str(e), e.status) from e result_text = await resp.text() raise TransportProtocolError( - f"Server did not return a GraphQL result: {result_text}" + f"Server did not return a GraphQL result: " + f"{reason}: " + f"{result_text}" ) + try: + result = await resp.json() + + if log.isEnabledFor(logging.INFO): + result_text = await resp.text() + log.info("<<< %s", result_text) + + except Exception: + await raise_response_error(resp, "Not a JSON answer") + if "errors" not in result and "data" not in result: - result_text = await resp.text() - raise TransportProtocolError( - "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: ' - f"{result_text}" - ) + await raise_response_error(resp, 'No "data" or "errors" keys in answer') return ExecutionResult( errors=result.get("errors"), diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 4df2ec43..899d5d66 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -18,6 +18,10 @@ class TransportServerError(TransportError): This exception will close the transport connection. """ + def __init__(self, message=None, code=None): + super(TransportServerError, self).__init__(message) + self.code = code + class TransportQueryError(Exception): """The server returned an error for a specific query. diff --git a/gql/transport/requests.py b/gql/transport/requests.py index c7d03adb..d0bc1467 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -38,7 +38,7 @@ def __init__( verify: bool = True, retries: int = 0, method: str = "POST", - **kwargs: Any + **kwargs: Any, ): """Initialize the transport with the given request parameters. @@ -150,26 +150,35 @@ def execute( # type: ignore response = self.session.request( self.method, self.url, **post_args # type: ignore ) - try: - result = response.json() - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) - except Exception: + def raise_response_error(resp: requests.Response, reason: str): # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases try: - # Raise a requests.HTTPerror if response status is 400 or higher - response.raise_for_status() - + # Raise a HTTPError if response status is 400 or higher + resp.raise_for_status() except requests.HTTPError as e: - raise TransportServerError(str(e)) + raise TransportServerError(str(e), e.response.status_code) from e + + result_text = resp.text + raise TransportProtocolError( + f"Server did not return a GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - raise TransportProtocolError("Server did not return a GraphQL result") + try: + result = response.json() + + if log.isEnabledFor(logging.INFO): + log.info("<<< %s", response.text) + + except Exception: + raise_response_error(response, "Not a JSON answer") if "errors" not in result and "data" not in result: - raise TransportProtocolError("Server did not return a GraphQL result") + raise_response_error(response, 'No "data" or "errors" keys in answer') return ExecutionResult( errors=result.get("errors"), diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 5c135b29..3fb85cd0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -102,6 +102,37 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.asyncio +async def test_aiohttp_error_code_401(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "401, message='Unauthorized'" in str(exc_info.value) + + @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp import web @@ -163,20 +194,20 @@ async def handler(request): "response": "{}", "expected_exception": ( "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: {}' + 'No "data" or "errors" keys in answer: {}' ), }, { "response": "qlsjfqsdlkj", "expected_exception": ( - "Server did not return a GraphQL result: " "qlsjfqsdlkj" + "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" ), }, { "response": '{"not_data_or_errors": 35}', "expected_exception": ( "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: {"not_data_or_errors": 35}' + 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, ] diff --git a/tests/test_client.py b/tests/test_client.py index f2a7ecf8..1521eac7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ from gql import Client, gql from gql.transport import Transport +from gql.transport.exceptions import TransportQueryError with suppress(ModuleNotFoundError): from urllib3.exceptions import NewConnectionError @@ -105,7 +106,7 @@ def test_execute_result_error(): """ ) - with pytest.raises(Exception) as exc_info: + with pytest.raises(TransportQueryError) as exc_info: client.execute(failing_query) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) diff --git a/tests/test_requests.py b/tests/test_requests.py index 2afbd84a..e18875a2 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -101,6 +101,41 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "401 Client Error: Unauthorized" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test):