Skip to content

Commit d3e88a0

Browse files
authored
Merge branch 'master' into websockets_transport_allow_to_specify_valid_subprotocols
2 parents dbf559f + d3be916 commit d3e88a0

File tree

4 files changed

+88
-40
lines changed

4 files changed

+88
-40
lines changed

gql/client.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,15 @@ async def __aenter__(self):
271271
self.session = AsyncClientSession(client=self)
272272

273273
# Get schema from transport if needed
274-
if self.fetch_schema_from_transport and not self.schema:
275-
await self.session.fetch_schema()
274+
try:
275+
if self.fetch_schema_from_transport and not self.schema:
276+
await self.session.fetch_schema()
277+
except Exception:
278+
# we don't know what type of exception is thrown here because it
279+
# depends on the underlying transport; we just make sure that the
280+
# transport is closed and re-raise the exception
281+
await self.transport.close()
282+
raise
276283

277284
return self.session
278285

@@ -293,8 +300,15 @@ def __enter__(self):
293300
self.session = SyncClientSession(client=self)
294301

295302
# Get schema from transport if needed
296-
if self.fetch_schema_from_transport and not self.schema:
297-
self.session.fetch_schema()
303+
try:
304+
if self.fetch_schema_from_transport and not self.schema:
305+
self.session.fetch_schema()
306+
except Exception:
307+
# we don't know what type of exception is thrown here because it
308+
# depends on the underlying transport; we just make sure that the
309+
# transport is closed and re-raise the exception
310+
self.transport.close()
311+
raise
298312

299313
return self.session
300314

gql/transport/websockets.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def _parse_answer_graphqlws(
279279
- instead of a unidirectional keep-alive (ka) message from server to client,
280280
there is now the possibility to send bidirectional ping/pong messages
281281
- connection_ack has an optional payload
282+
- the 'error' answer type returns a list of errors instead of a single error
282283
"""
283284

284285
answer_type: str = ""
@@ -295,11 +296,11 @@ def _parse_answer_graphqlws(
295296

296297
payload = json_answer.get("payload")
297298

298-
if not isinstance(payload, dict):
299-
raise ValueError("payload is not a dict")
300-
301299
if answer_type == "next":
302300

301+
if not isinstance(payload, dict):
302+
raise ValueError("payload is not a dict")
303+
303304
if "errors" not in payload and "data" not in payload:
304305
raise ValueError(
305306
"payload does not contain 'data' or 'errors' fields"
@@ -316,8 +317,11 @@ def _parse_answer_graphqlws(
316317

317318
elif answer_type == "error":
318319

320+
if not isinstance(payload, list):
321+
raise ValueError("payload is not a list")
322+
319323
raise TransportQueryError(
320-
str(payload), query_id=answer_id, errors=[payload]
324+
str(payload[0]), query_id=answer_id, errors=payload
321325
)
322326

323327
elif answer_type in ["ping", "pong", "connection_ack"]:

tests/test_client.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,50 @@ def test_gql():
200200
client = Client(schema=schema)
201201
result = client.execute(query)
202202
assert result["user"] is None
203+
204+
205+
@pytest.mark.requests
206+
def test_sync_transport_close_on_schema_retrieval_failure():
207+
"""
208+
Ensure that the transport session is closed if an error occurs when
209+
entering the context manager (e.g., because schema retrieval fails)
210+
"""
211+
212+
from gql.transport.requests import RequestsHTTPTransport
213+
214+
transport = RequestsHTTPTransport(url="http://localhost/")
215+
client = Client(transport=transport, fetch_schema_from_transport=True)
216+
217+
try:
218+
with client:
219+
pass
220+
except Exception:
221+
# we don't care what exception is thrown, we just want to check if the
222+
# transport is closed afterwards
223+
pass
224+
225+
assert client.transport.session is None
226+
227+
228+
@pytest.mark.aiohttp
229+
@pytest.mark.asyncio
230+
async def test_async_transport_close_on_schema_retrieval_failure():
231+
"""
232+
Ensure that the transport session is closed if an error occurs when
233+
entering the context manager (e.g., because schema retrieval fails)
234+
"""
235+
236+
from gql.transport.aiohttp import AIOHTTPTransport
237+
238+
transport = AIOHTTPTransport(url="http://localhost/")
239+
client = Client(transport=transport, fetch_schema_from_transport=True)
240+
241+
try:
242+
async with client:
243+
pass
244+
except Exception:
245+
# we don't care what exception is thrown, we just want to check if the
246+
# transport is closed afterwards
247+
pass
248+
249+
assert client.transport.session is None

tests/test_graphqlws_exceptions.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import asyncio
2-
import json
3-
import types
42
from typing import List
53

64
import pytest
@@ -125,49 +123,29 @@ async def test_graphqlws_server_does_not_send_ack(
125123
pass
126124

127125

128-
invalid_payload_server_answer = (
129-
'{"type":"error","id":"1","payload":{"message":"Must provide document"}}'
126+
invalid_query_server_answer = (
127+
'{"id":"1","type":"error","payload":[{"message":"Cannot query field '
128+
'\\"helo\\" on type \\"Query\\". Did you mean \\"hello\\"?",'
129+
'"locations":[{"line":2,"column":3}]}]}'
130130
)
131131

132132

133-
async def server_invalid_payload(ws, path):
133+
async def server_invalid_query(ws, path):
134134
await WebSocketServerHelper.send_connection_ack(ws)
135135
result = await ws.recv()
136136
print(f"Server received: {result}")
137-
await ws.send(invalid_payload_server_answer)
137+
await ws.send(invalid_query_server_answer)
138138
await WebSocketServerHelper.wait_connection_terminate(ws)
139139
await ws.wait_closed()
140140

141141

142142
@pytest.mark.asyncio
143-
@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True)
144-
@pytest.mark.parametrize("query_str", [invalid_query_str])
145-
async def test_graphqlws_sending_invalid_payload(
146-
event_loop, client_and_graphqlws_server, query_str
147-
):
143+
@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True)
144+
async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server):
148145

149146
session, server = client_and_graphqlws_server
150147

151-
# Monkey patching the _send_query method to send an invalid payload
152-
153-
async def monkey_patch_send_query(
154-
self, document, variable_values=None, operation_name=None,
155-
) -> int:
156-
query_id = self.next_query_id
157-
self.next_query_id += 1
158-
159-
query_str = json.dumps(
160-
{"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"}
161-
)
162-
163-
await self._send(query_str)
164-
return query_id
165-
166-
session.transport._send_query = types.MethodType(
167-
monkey_patch_send_query, session.transport
168-
)
169-
170-
query = gql(query_str)
148+
query = gql("{helo}")
171149

172150
with pytest.raises(TransportQueryError) as exc_info:
173151
await session.execute(query)
@@ -178,7 +156,10 @@ async def monkey_patch_send_query(
178156

179157
error = exception.errors[0]
180158

181-
assert error["message"] == "Must provide document"
159+
assert (
160+
error["message"]
161+
== 'Cannot query field "helo" on type "Query". Did you mean "hello"?'
162+
)
182163

183164

184165
not_json_answer = ["BLAHBLAH"]
@@ -188,6 +169,7 @@ async def monkey_patch_send_query(
188169
missing_id_answer_3 = ['{"type": "complete"}']
189170
data_without_payload = ['{"type": "next", "id":"1"}']
190171
error_without_payload = ['{"type": "error", "id":"1"}']
172+
error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}']
191173
payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}']
192174
empty_payload = ['{"type": "next", "id":"1", "payload": {}}']
193175
sending_bytes = [b"\x01\x02\x03"]
@@ -205,6 +187,7 @@ async def monkey_patch_send_query(
205187
data_without_payload,
206188
error_without_payload,
207189
payload_is_not_a_dict,
190+
error_with_payload_not_a_list,
208191
empty_payload,
209192
sending_bytes,
210193
],

0 commit comments

Comments
 (0)