Skip to content

Graceful shutdown for synchronous subscriptions #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ async def subscribe_async(
) -> AsyncGenerator[Dict, None]:
async with self as session:

self._generator: AsyncGenerator[Dict, None] = session.subscribe(
generator: AsyncGenerator[Dict, None] = session.subscribe(
document, *args, **kwargs
)

async for result in self._generator:
async for result in generator:
yield result

def subscribe(
Expand All @@ -145,12 +145,29 @@ def subscribe(

try:
while True:
result = loop.run_until_complete(async_generator.__anext__())
# Note: we need to create a task here in order to be able to close
# the async generator properly on python 3.8
# See https://bugs.python.org/issue38559
generator_task = asyncio.ensure_future(async_generator.__anext__())
result = loop.run_until_complete(generator_task)
yield result

except StopAsyncIteration:
pass

except (KeyboardInterrupt, Exception):

# Graceful shutdown by cancelling the task and waiting clean shutdown
generator_task.cancel()

try:
loop.run_until_complete(generator_task)
except (StopAsyncIteration, asyncio.CancelledError):
pass

# Then reraise the exception
raise

async def __aenter__(self):

assert isinstance(
Expand Down
71 changes: 69 additions & 2 deletions tests/test_websocket_subscription.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import json
import sys
from typing import List

import pytest
import websockets
Expand All @@ -17,14 +19,22 @@
WITH_KEEPALIVE = False


# List which can used to store received messages by the server
logged_messages: List[str] = []


async def server_countdown(ws, path):
logged_messages.clear()

global WITH_KEEPALIVE
try:
await WebSocketServer.send_connection_ack(ws)
if WITH_KEEPALIVE:
await WebSocketServer.send_keepalive(ws)

result = await ws.recv()
logged_messages.append(result)

json_result = json.loads(result)
assert json_result["type"] == "start"
payload = json_result["payload"]
Expand All @@ -48,7 +58,12 @@ async def stopping_coro():
nonlocal counting_task
while True:

result = await ws.recv()
try:
result = await ws.recv()
logged_messages.append(result)
except websockets.exceptions.ConnectionClosed:
break

json_result = json.loads(result)

if json_result["type"] == "stop" and json_result["id"] == str(query_id):
Expand All @@ -58,7 +73,10 @@ async def stopping_coro():
async def keepalive_coro():
while True:
await asyncio.sleep(5 * MS)
await WebSocketServer.send_keepalive(ws)
try:
await WebSocketServer.send_keepalive(ws)
except websockets.exceptions.ConnectionClosed:
break

stopping_task = asyncio.ensure_future(stopping_coro())
keepalive_task = asyncio.ensure_future(keepalive_coro())
Expand Down Expand Up @@ -351,3 +369,52 @@ def test_websocket_subscription_sync(server, subscription_str):
count -= 1

assert count == -1


@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows")
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str):
""" Note: this test will simulate a control-C happening while a sync subscription
is in progress. To do that we will throw a KeyboardInterrupt exception inside
the subscription async generator.

The code should then do a clean close:
- send stop messages for each active query
- send a connection_terminate message
Then the KeyboardInterrupt will be reraise (to warn potential user code)

This test does not work on Windows but the behaviour with Windows is correct.
"""

url = f"ws://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

with pytest.raises(KeyboardInterrupt):
for result in client.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count

if count == 5:

# Simulate a KeyboardInterrupt in the generator
asyncio.ensure_future(
client.session._generator.athrow(KeyboardInterrupt)
)

count -= 1

assert count == 4

# Check that the server received a connection_terminate message last
assert logged_messages.pop() == '{"type": "connection_terminate"}'