Skip to content

Commit b6b7734

Browse files
authored
Ignore keepalives before connection_ack in websockets transport (#110)
1 parent 29f3fa7 commit b6b7734

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

gql/transport/websockets.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,23 @@ async def _receive(self) -> str:
175175

176176
return answer
177177

178+
async def _wait_ack(self) -> None:
179+
"""Wait for the connection_ack message. Keep alive messages are ignored
180+
"""
181+
182+
while True:
183+
init_answer = await self._receive()
184+
185+
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
186+
187+
if answer_type == "connection_ack":
188+
return
189+
190+
if answer_type != "ka":
191+
raise TransportProtocolError(
192+
"Websocket server did not return a connection ack"
193+
)
194+
178195
async def _send_init_message_and_wait_ack(self) -> None:
179196
"""Send init message to the provided websocket and wait for the connection ACK.
180197
@@ -188,14 +205,7 @@ async def _send_init_message_and_wait_ack(self) -> None:
188205
await self._send(init_message)
189206

190207
# Wait for the connection_ack message or raise a TimeoutError
191-
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)
192-
193-
answer_type, answer_id, execution_result = self._parse_answer(init_answer)
194-
195-
if answer_type != "connection_ack":
196-
raise TransportProtocolError(
197-
"Websocket server did not return a connection ack"
198-
)
208+
await asyncio.wait_for(self._wait_ack(), self.ack_timeout)
199209

200210
async def _send_stop_message(self, query_id: int) -> None:
201211
"""Send stop message to the provided websocket connection and query_id.

tests/test_websocket_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ async def test_websocket_transport_protocol_errors(event_loop, client_and_server
241241

242242
async def server_without_ack(ws, path):
243243
# Sending something else than an ack
244-
await WebSocketServer.send_keepalive(ws)
244+
await WebSocketServer.send_complete(ws, 1)
245245
await ws.wait_closed()
246246

247247

tests/test_websocket_query.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,3 +472,43 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server):
472472

473473
async with Client(transport=sample_transport) as session:
474474
await session.execute(query)
475+
476+
477+
async def server_sending_keep_alive_before_connection_ack(ws, path):
478+
await WebSocketServer.send_keepalive(ws)
479+
await WebSocketServer.send_keepalive(ws)
480+
await WebSocketServer.send_keepalive(ws)
481+
await WebSocketServer.send_keepalive(ws)
482+
await WebSocketServer.send_connection_ack(ws)
483+
result = await ws.recv()
484+
print(f"Server received: {result}")
485+
await ws.send(query1_server_answer.format(query_id=1))
486+
await WebSocketServer.send_complete(ws, 1)
487+
await ws.wait_closed()
488+
489+
490+
@pytest.mark.asyncio
491+
@pytest.mark.parametrize(
492+
"server", [server_sending_keep_alive_before_connection_ack], indirect=True
493+
)
494+
@pytest.mark.parametrize("query_str", [query1_str])
495+
async def test_websocket_non_regression_bug_108(
496+
event_loop, client_and_server, query_str
497+
):
498+
499+
# This test will check that we now ignore keepalive message
500+
# arriving before the connection_ack
501+
# See bug #108
502+
503+
session, server = client_and_server
504+
505+
query = gql(query_str)
506+
507+
result = await session.execute(query)
508+
509+
print("Client received:", result)
510+
511+
continents = result["continents"]
512+
africa = continents[0]
513+
514+
assert africa["code"] == "AF"

0 commit comments

Comments
 (0)