Skip to content

Commit f51179e

Browse files
authored
Merge branch 'master' into fix_race_condition_in_websockets_connect
2 parents ac28f70 + 87cc5b2 commit f51179e

File tree

4 files changed

+89
-22
lines changed

4 files changed

+89
-22
lines changed

gql/client.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
2-
from inspect import isawaitable
3-
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast
2+
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
43

54
from graphql import (
65
DocumentNode,
@@ -196,19 +195,22 @@ class SyncClientSession:
196195
def __init__(self, client: Client):
197196
self.client = client
198197

199-
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
198+
def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
200199

201200
# Validate document
202201
if self.client.schema:
203202
self.client.validate(document)
204203

205-
result = self.transport.execute(document, *args, **kwargs)
204+
return self.transport.execute(document, *args, **kwargs)
205+
206+
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
206207

207-
assert not isawaitable(result), "Transport returned an awaitable result."
208-
result = cast(ExecutionResult, result)
208+
# Validate and execute on the transport
209+
result = self._execute(document, *args, **kwargs)
209210

211+
# Raise an error if an error is returned in the ExecutionResult object
210212
if result.errors:
211-
raise TransportQueryError(str(result.errors[0]))
213+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
212214

213215
assert (
214216
result.data is not None
@@ -250,43 +252,69 @@ async def fetch_and_validate(self, document: DocumentNode):
250252
if self.client.schema:
251253
self.client.validate(document)
252254

253-
async def subscribe(
255+
async def _subscribe(
254256
self, document: DocumentNode, *args, **kwargs
255-
) -> AsyncGenerator[Dict, None]:
257+
) -> AsyncGenerator[ExecutionResult, None]:
256258

257259
# Fetch schema from transport if needed and validate document if possible
258260
await self.fetch_and_validate(document)
259261

260-
# Subscribe to the transport and yield data or raise error
261-
self._generator: AsyncGenerator[
262+
# Subscribe to the transport
263+
inner_generator: AsyncGenerator[
262264
ExecutionResult, None
263265
] = self.transport.subscribe(document, *args, **kwargs)
264266

265-
async for result in self._generator:
267+
# Keep a reference to the inner generator to allow the user to call aclose()
268+
# before a break if python version is too old (pypy3 py 3.6.1)
269+
self._generator = inner_generator
270+
271+
async for result in inner_generator:
266272
if result.errors:
267273
# Note: we need to run generator.aclose() here or the finally block in
268274
# transport.subscribe will not be reached in pypy3 (py 3.6.1)
269-
await self._generator.aclose()
275+
await inner_generator.aclose()
276+
277+
yield result
278+
279+
async def subscribe(
280+
self, document: DocumentNode, *args, **kwargs
281+
) -> AsyncGenerator[Dict, None]:
270282

271-
raise TransportQueryError(str(result.errors[0]))
283+
# Validate and subscribe on the transport
284+
async for result in self._subscribe(document, *args, **kwargs):
285+
286+
# Raise an error if an error is returned in the ExecutionResult object
287+
if result.errors:
288+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
272289

273290
elif result.data is not None:
274291
yield result.data
275292

276-
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
293+
async def _execute(
294+
self, document: DocumentNode, *args, **kwargs
295+
) -> ExecutionResult:
277296

278297
# Fetch schema from transport if needed and validate document if possible
279298
await self.fetch_and_validate(document)
280299

281300
# Execute the query with the transport with a timeout
282-
result = await asyncio.wait_for(
301+
return await asyncio.wait_for(
283302
self.transport.execute(document, *args, **kwargs),
284303
self.client.execute_timeout,
285304
)
286305

306+
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
307+
308+
# Validate and execute on the transport
309+
result = await self._execute(document, *args, **kwargs)
310+
287311
# Raise an error if an error is returned in the ExecutionResult object
288312
if result.errors:
289-
raise TransportQueryError(str(result.errors[0]))
313+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
314+
315+
assert (
316+
result.data is not None
317+
), "Transport returned an ExecutionResult without data or errors"
290318

291319
return result.data
292320

gql/transport/exceptions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Any, List, Optional
2+
3+
14
class TransportError(Exception):
25
pass
36

@@ -22,9 +25,15 @@ class TransportQueryError(Exception):
2225
This exception should not close the transport connection.
2326
"""
2427

25-
def __init__(self, msg, query_id=None):
28+
def __init__(
29+
self,
30+
msg: str,
31+
query_id: Optional[int] = None,
32+
errors: Optional[List[Any]] = None,
33+
):
2634
super().__init__(msg)
2735
self.query_id = query_id
36+
self.errors = errors
2837

2938

3039
class TransportClosed(TransportError):

gql/transport/websockets.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def _parse_answer(
293293

294294
elif answer_type == "error":
295295

296-
raise TransportQueryError(str(payload), query_id=answer_id)
296+
raise TransportQueryError(
297+
str(payload), query_id=answer_id, errors=[payload]
298+
)
297299

298300
elif answer_type == "ka":
299301
# KeepAlive message
@@ -335,6 +337,9 @@ async def _receive_data_loop(self) -> None:
335337
# ==> Add an exception to this query queue
336338
# The exception is raised for this specific query,
337339
# but the transport is not closed.
340+
assert isinstance(
341+
e.query_id, int
342+
), "TransportQueryError should have a query_id defined here"
338343
try:
339344
await self.listeners[e.query_id].set_exception(e)
340345
except KeyError:

tests/test_websocket_exceptions.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import types
4+
from typing import List
45

56
import pytest
67
import websockets
@@ -45,9 +46,17 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str)
4546

4647
query = gql(query_str)
4748

48-
with pytest.raises(TransportQueryError):
49+
with pytest.raises(TransportQueryError) as exc_info:
4950
await session.execute(query)
5051

52+
exception = exc_info.value
53+
54+
assert isinstance(exception.errors, List)
55+
56+
error = exception.errors[0]
57+
58+
assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"
59+
5160

5261
invalid_subscription_str = """
5362
subscription getContinents {
@@ -76,10 +85,18 @@ async def test_websocket_invalid_subscription(event_loop, client_and_server, que
7685

7786
query = gql(query_str)
7887

79-
with pytest.raises(TransportQueryError):
88+
with pytest.raises(TransportQueryError) as exc_info:
8089
async for result in session.subscribe(query):
8190
pass
8291

92+
exception = exc_info.value
93+
94+
assert isinstance(exception.errors, List)
95+
96+
error = exception.errors[0]
97+
98+
assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"
99+
83100

84101
connection_error_server_answer = (
85102
'{"type":"connection_error","id":null,'
@@ -171,9 +188,17 @@ async def monkey_patch_send_query(
171188

172189
query = gql(query_str)
173190

174-
with pytest.raises(TransportQueryError):
191+
with pytest.raises(TransportQueryError) as exc_info:
175192
await session.execute(query)
176193

194+
exception = exc_info.value
195+
196+
assert isinstance(exception.errors, List)
197+
198+
error = exception.errors[0]
199+
200+
assert error["message"] == "Must provide document"
201+
177202

178203
not_json_answer = ["BLAHBLAH"]
179204
missing_type_answer = ["{}"]

0 commit comments

Comments
 (0)