Skip to content

Commit 73d0ba5

Browse files
Fix 4xx error handling in transports (#195)
1 parent c4f2dc2 commit 73d0ba5

File tree

6 files changed

+112
-31
lines changed

6 files changed

+112
-31
lines changed

gql/transport/aiohttp.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -206,35 +206,36 @@ async def execute(
206206
raise TransportClosed("Transport is not connected")
207207

208208
async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp:
209-
try:
210-
result = await resp.json()
211209

212-
if log.isEnabledFor(logging.INFO):
213-
result_text = await resp.text()
214-
log.info("<<< %s", result_text)
215-
except Exception:
210+
async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
216211
# We raise a TransportServerError if the status code is 400 or higher
217212
# We raise a TransportProtocolError in the other cases
218213

219214
try:
220215
# Raise a ClientResponseError if response status is 400 or higher
221216
resp.raise_for_status()
222-
223217
except ClientResponseError as e:
224-
raise TransportServerError(str(e)) from e
218+
raise TransportServerError(str(e), e.status) from e
225219

226220
result_text = await resp.text()
227221
raise TransportProtocolError(
228-
f"Server did not return a GraphQL result: {result_text}"
222+
f"Server did not return a GraphQL result: "
223+
f"{reason}: "
224+
f"{result_text}"
229225
)
230226

227+
try:
228+
result = await resp.json()
229+
230+
if log.isEnabledFor(logging.INFO):
231+
result_text = await resp.text()
232+
log.info("<<< %s", result_text)
233+
234+
except Exception:
235+
await raise_response_error(resp, "Not a JSON answer")
236+
231237
if "errors" not in result and "data" not in result:
232-
result_text = await resp.text()
233-
raise TransportProtocolError(
234-
"Server did not return a GraphQL result: "
235-
'No "data" or "error" keys in answer: '
236-
f"{result_text}"
237-
)
238+
await raise_response_error(resp, 'No "data" or "errors" keys in answer')
238239

239240
return ExecutionResult(
240241
errors=result.get("errors"),

gql/transport/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class TransportServerError(TransportError):
1818
This exception will close the transport connection.
1919
"""
2020

21+
def __init__(self, message=None, code=None):
22+
super(TransportServerError, self).__init__(message)
23+
self.code = code
24+
2125

2226
class TransportQueryError(Exception):
2327
"""The server returned an error for a specific query.

gql/transport/requests.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
verify: bool = True,
3939
retries: int = 0,
4040
method: str = "POST",
41-
**kwargs: Any
41+
**kwargs: Any,
4242
):
4343
"""Initialize the transport with the given request parameters.
4444
@@ -150,26 +150,35 @@ def execute( # type: ignore
150150
response = self.session.request(
151151
self.method, self.url, **post_args # type: ignore
152152
)
153-
try:
154-
result = response.json()
155153

156-
if log.isEnabledFor(logging.INFO):
157-
log.info("<<< %s", response.text)
158-
except Exception:
154+
def raise_response_error(resp: requests.Response, reason: str):
159155
# We raise a TransportServerError if the status code is 400 or higher
160156
# We raise a TransportProtocolError in the other cases
161157

162158
try:
163-
# Raise a requests.HTTPerror if response status is 400 or higher
164-
response.raise_for_status()
165-
159+
# Raise a HTTPError if response status is 400 or higher
160+
resp.raise_for_status()
166161
except requests.HTTPError as e:
167-
raise TransportServerError(str(e))
162+
raise TransportServerError(str(e), e.response.status_code) from e
163+
164+
result_text = resp.text
165+
raise TransportProtocolError(
166+
f"Server did not return a GraphQL result: "
167+
f"{reason}: "
168+
f"{result_text}"
169+
)
168170

169-
raise TransportProtocolError("Server did not return a GraphQL result")
171+
try:
172+
result = response.json()
173+
174+
if log.isEnabledFor(logging.INFO):
175+
log.info("<<< %s", response.text)
176+
177+
except Exception:
178+
raise_response_error(response, "Not a JSON answer")
170179

171180
if "errors" not in result and "data" not in result:
172-
raise TransportProtocolError("Server did not return a GraphQL result")
181+
raise_response_error(response, 'No "data" or "errors" keys in answer')
173182

174183
return ExecutionResult(
175184
errors=result.get("errors"),

tests/test_aiohttp.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,37 @@ async def handler(request):
102102
assert africa["code"] == "AF"
103103

104104

105+
@pytest.mark.asyncio
106+
async def test_aiohttp_error_code_401(event_loop, aiohttp_server):
107+
from aiohttp import web
108+
from gql.transport.aiohttp import AIOHTTPTransport
109+
110+
async def handler(request):
111+
# Will generate http error code 401
112+
return web.Response(
113+
text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}',
114+
content_type="application/json",
115+
status=401,
116+
)
117+
118+
app = web.Application()
119+
app.router.add_route("POST", "/", handler)
120+
server = await aiohttp_server(app)
121+
122+
url = server.make_url("/")
123+
124+
sample_transport = AIOHTTPTransport(url=url)
125+
126+
async with Client(transport=sample_transport,) as session:
127+
128+
query = gql(query1_str)
129+
130+
with pytest.raises(TransportServerError) as exc_info:
131+
await session.execute(query)
132+
133+
assert "401, message='Unauthorized'" in str(exc_info.value)
134+
135+
105136
@pytest.mark.asyncio
106137
async def test_aiohttp_error_code_500(event_loop, aiohttp_server):
107138
from aiohttp import web
@@ -163,20 +194,20 @@ async def handler(request):
163194
"response": "{}",
164195
"expected_exception": (
165196
"Server did not return a GraphQL result: "
166-
'No "data" or "error" keys in answer: {}'
197+
'No "data" or "errors" keys in answer: {}'
167198
),
168199
},
169200
{
170201
"response": "qlsjfqsdlkj",
171202
"expected_exception": (
172-
"Server did not return a GraphQL result: " "qlsjfqsdlkj"
203+
"Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj"
173204
),
174205
},
175206
{
176207
"response": '{"not_data_or_errors": 35}',
177208
"expected_exception": (
178209
"Server did not return a GraphQL result: "
179-
'No "data" or "error" keys in answer: {"not_data_or_errors": 35}'
210+
'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}'
180211
),
181212
},
182213
]

tests/test_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from gql import Client, gql
99
from gql.transport import Transport
10+
from gql.transport.exceptions import TransportQueryError
1011

1112
with suppress(ModuleNotFoundError):
1213
from urllib3.exceptions import NewConnectionError
@@ -105,7 +106,7 @@ def test_execute_result_error():
105106
"""
106107
)
107108

108-
with pytest.raises(Exception) as exc_info:
109+
with pytest.raises(TransportQueryError) as exc_info:
109110
client.execute(failing_query)
110111
assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value)
111112

tests/test_requests.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,41 @@ def test_code():
101101
await run_sync_test(event_loop, server, test_code)
102102

103103

104+
@pytest.mark.aiohttp
105+
@pytest.mark.asyncio
106+
async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test):
107+
from aiohttp import web
108+
from gql.transport.requests import RequestsHTTPTransport
109+
110+
async def handler(request):
111+
# Will generate http error code 401
112+
return web.Response(
113+
text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}',
114+
content_type="application/json",
115+
status=401,
116+
)
117+
118+
app = web.Application()
119+
app.router.add_route("POST", "/", handler)
120+
server = await aiohttp_server(app)
121+
122+
url = server.make_url("/")
123+
124+
def test_code():
125+
sample_transport = RequestsHTTPTransport(url=url)
126+
127+
with Client(transport=sample_transport,) as session:
128+
129+
query = gql(query1_str)
130+
131+
with pytest.raises(TransportServerError) as exc_info:
132+
session.execute(query)
133+
134+
assert "401 Client Error: Unauthorized" in str(exc_info.value)
135+
136+
await run_sync_test(event_loop, server, test_code)
137+
138+
104139
@pytest.mark.aiohttp
105140
@pytest.mark.asyncio
106141
async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test):

0 commit comments

Comments
 (0)