Skip to content

Commit 140547d

Browse files
authored
Graceful shutdown for synchronous subscriptions (#99)
* Graceful shutdown for synchronous subscriptions * Ignore test on windows
1 parent b6b7734 commit 140547d

File tree

2 files changed

+89
-5
lines changed

2 files changed

+89
-5
lines changed

gql/client.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ async def subscribe_async(
119119
) -> AsyncGenerator[Dict, None]:
120120
async with self as session:
121121

122-
self._generator: AsyncGenerator[Dict, None] = session.subscribe(
122+
generator: AsyncGenerator[Dict, None] = session.subscribe(
123123
document, *args, **kwargs
124124
)
125125

126-
async for result in self._generator:
126+
async for result in generator:
127127
yield result
128128

129129
def subscribe(
@@ -145,12 +145,29 @@ def subscribe(
145145

146146
try:
147147
while True:
148-
result = loop.run_until_complete(async_generator.__anext__())
148+
# Note: we need to create a task here in order to be able to close
149+
# the async generator properly on python 3.8
150+
# See https://bugs.python.org/issue38559
151+
generator_task = asyncio.ensure_future(async_generator.__anext__())
152+
result = loop.run_until_complete(generator_task)
149153
yield result
150154

151155
except StopAsyncIteration:
152156
pass
153157

158+
except (KeyboardInterrupt, Exception):
159+
160+
# Graceful shutdown by cancelling the task and waiting clean shutdown
161+
generator_task.cancel()
162+
163+
try:
164+
loop.run_until_complete(generator_task)
165+
except (StopAsyncIteration, asyncio.CancelledError):
166+
pass
167+
168+
# Then reraise the exception
169+
raise
170+
154171
async def __aenter__(self):
155172

156173
assert isinstance(

tests/test_websocket_subscription.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import json
3+
import sys
4+
from typing import List
35

46
import pytest
57
import websockets
@@ -17,14 +19,22 @@
1719
WITH_KEEPALIVE = False
1820

1921

22+
# List which can used to store received messages by the server
23+
logged_messages: List[str] = []
24+
25+
2026
async def server_countdown(ws, path):
27+
logged_messages.clear()
28+
2129
global WITH_KEEPALIVE
2230
try:
2331
await WebSocketServer.send_connection_ack(ws)
2432
if WITH_KEEPALIVE:
2533
await WebSocketServer.send_keepalive(ws)
2634

2735
result = await ws.recv()
36+
logged_messages.append(result)
37+
2838
json_result = json.loads(result)
2939
assert json_result["type"] == "start"
3040
payload = json_result["payload"]
@@ -48,7 +58,12 @@ async def stopping_coro():
4858
nonlocal counting_task
4959
while True:
5060

51-
result = await ws.recv()
61+
try:
62+
result = await ws.recv()
63+
logged_messages.append(result)
64+
except websockets.exceptions.ConnectionClosed:
65+
break
66+
5267
json_result = json.loads(result)
5368

5469
if json_result["type"] == "stop" and json_result["id"] == str(query_id):
@@ -58,7 +73,10 @@ async def stopping_coro():
5873
async def keepalive_coro():
5974
while True:
6075
await asyncio.sleep(5 * MS)
61-
await WebSocketServer.send_keepalive(ws)
76+
try:
77+
await WebSocketServer.send_keepalive(ws)
78+
except websockets.exceptions.ConnectionClosed:
79+
break
6280

6381
stopping_task = asyncio.ensure_future(stopping_coro())
6482
keepalive_task = asyncio.ensure_future(keepalive_coro())
@@ -351,3 +369,52 @@ def test_websocket_subscription_sync(server, subscription_str):
351369
count -= 1
352370

353371
assert count == -1
372+
373+
374+
@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows")
375+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
376+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
377+
def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str):
378+
""" Note: this test will simulate a control-C happening while a sync subscription
379+
is in progress. To do that we will throw a KeyboardInterrupt exception inside
380+
the subscription async generator.
381+
382+
The code should then do a clean close:
383+
- send stop messages for each active query
384+
- send a connection_terminate message
385+
Then the KeyboardInterrupt will be reraise (to warn potential user code)
386+
387+
This test does not work on Windows but the behaviour with Windows is correct.
388+
"""
389+
390+
url = f"ws://{server.hostname}:{server.port}/graphql"
391+
print(f"url = {url}")
392+
393+
sample_transport = WebsocketsTransport(url=url)
394+
395+
client = Client(transport=sample_transport)
396+
397+
count = 10
398+
subscription = gql(subscription_str.format(count=count))
399+
400+
with pytest.raises(KeyboardInterrupt):
401+
for result in client.subscribe(subscription):
402+
403+
number = result["number"]
404+
print(f"Number received: {number}")
405+
406+
assert number == count
407+
408+
if count == 5:
409+
410+
# Simulate a KeyboardInterrupt in the generator
411+
asyncio.ensure_future(
412+
client.session._generator.athrow(KeyboardInterrupt)
413+
)
414+
415+
count -= 1
416+
417+
assert count == 4
418+
419+
# Check that the server received a connection_terminate message last
420+
assert logged_messages.pop() == '{"type": "connection_terminate"}'

0 commit comments

Comments
 (0)