Skip to content

Commit 26c577f

Browse files
committed
Graceful shutdown for synchronous subscriptions
1 parent ce13236 commit 26c577f

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

gql/client.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ async def subscribe_async(
120120
) -> AsyncGenerator[Dict, None]:
121121
async with self as session:
122122

123-
self._generator: AsyncGenerator[Dict, None] = session.subscribe(
123+
generator: AsyncGenerator[Dict, None] = session.subscribe(
124124
document, *args, **kwargs
125125
)
126126

127-
async for result in self._generator:
127+
async for result in generator:
128128
yield result
129129

130130
def subscribe(
@@ -146,12 +146,29 @@ def subscribe(
146146

147147
try:
148148
while True:
149-
result = loop.run_until_complete(async_generator.__anext__())
149+
# Note: we need to create a task here in order to be able to close
150+
# the async generator properly on python 3.8
151+
# See https://bugs.python.org/issue38559
152+
generator_task = asyncio.ensure_future(async_generator.__anext__())
153+
result = loop.run_until_complete(generator_task)
150154
yield result
151155

152156
except StopAsyncIteration:
153157
pass
154158

159+
except (KeyboardInterrupt, Exception):
160+
161+
# Graceful shutdown by cancelling the task and waiting clean shutdown
162+
generator_task.cancel()
163+
164+
try:
165+
loop.run_until_complete(generator_task)
166+
except (StopAsyncIteration, asyncio.CancelledError):
167+
pass
168+
169+
# Then reraise the exception
170+
raise
171+
155172
async def __aenter__(self):
156173

157174
assert isinstance(

tests/test_websocket_subscription.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
from typing import List
34

45
import pytest
56
import websockets
@@ -17,14 +18,22 @@
1718
WITH_KEEPALIVE = False
1819

1920

21+
# List which can used to store received messages by the server
22+
logged_messages: List[str] = []
23+
24+
2025
async def server_countdown(ws, path):
26+
logged_messages.clear()
27+
2128
global WITH_KEEPALIVE
2229
try:
2330
await WebSocketServer.send_connection_ack(ws)
2431
if WITH_KEEPALIVE:
2532
await WebSocketServer.send_keepalive(ws)
2633

2734
result = await ws.recv()
35+
logged_messages.append(result)
36+
2837
json_result = json.loads(result)
2938
assert json_result["type"] == "start"
3039
payload = json_result["payload"]
@@ -48,7 +57,12 @@ async def stopping_coro():
4857
nonlocal counting_task
4958
while True:
5059

51-
result = await ws.recv()
60+
try:
61+
result = await ws.recv()
62+
logged_messages.append(result)
63+
except websockets.exceptions.ConnectionClosed:
64+
break
65+
5266
json_result = json.loads(result)
5367

5468
if json_result["type"] == "stop" and json_result["id"] == str(query_id):
@@ -58,7 +72,10 @@ async def stopping_coro():
5872
async def keepalive_coro():
5973
while True:
6074
await asyncio.sleep(5 * MS)
61-
await WebSocketServer.send_keepalive(ws)
75+
try:
76+
await WebSocketServer.send_keepalive(ws)
77+
except websockets.exceptions.ConnectionClosed:
78+
break
6279

6380
stopping_task = asyncio.ensure_future(stopping_coro())
6481
keepalive_task = asyncio.ensure_future(keepalive_coro())
@@ -351,3 +368,40 @@ def test_websocket_subscription_sync(server, subscription_str):
351368
count -= 1
352369

353370
assert count == -1
371+
372+
373+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
374+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
375+
def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str):
376+
377+
url = f"ws://{server.hostname}:{server.port}/graphql"
378+
print(f"url = {url}")
379+
380+
sample_transport = WebsocketsTransport(url=url)
381+
382+
client = Client(transport=sample_transport)
383+
384+
count = 10
385+
subscription = gql(subscription_str.format(count=count))
386+
387+
with pytest.raises(KeyboardInterrupt):
388+
for result in client.subscribe(subscription):
389+
390+
number = result["number"]
391+
print(f"Number received: {number}")
392+
393+
assert number == count
394+
395+
if count == 5:
396+
397+
# Simulate a KeyboardInterrupt in the generator
398+
asyncio.ensure_future(
399+
client.session._generator.athrow(KeyboardInterrupt)
400+
)
401+
402+
count -= 1
403+
404+
assert count == 4
405+
406+
# Check that the server received a connection_terminate message last
407+
assert logged_messages.pop() == '{"type": "connection_terminate"}'

0 commit comments

Comments
 (0)