From 26c577f42cfff408870ec1eb5437e04712980562 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 3 Jun 2020 22:49:07 +0200 Subject: [PATCH 1/2] Graceful shutdown for synchronous subscriptions --- gql/client.py | 23 +++++++++-- tests/test_websocket_subscription.py | 58 +++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/gql/client.py b/gql/client.py index 76cc2395..46d74f4e 100644 --- a/gql/client.py +++ b/gql/client.py @@ -120,11 +120,11 @@ async def subscribe_async( ) -> AsyncGenerator[Dict, None]: async with self as session: - self._generator: AsyncGenerator[Dict, None] = session.subscribe( + generator: AsyncGenerator[Dict, None] = session.subscribe( document, *args, **kwargs ) - async for result in self._generator: + async for result in generator: yield result def subscribe( @@ -146,12 +146,29 @@ def subscribe( try: while True: - result = loop.run_until_complete(async_generator.__anext__()) + # Note: we need to create a task here in order to be able to close + # the async generator properly on python 3.8 + # See https://bugs.python.org/issue38559 + generator_task = asyncio.ensure_future(async_generator.__anext__()) + result = loop.run_until_complete(generator_task) yield result except StopAsyncIteration: pass + except (KeyboardInterrupt, Exception): + + # Graceful shutdown by cancelling the task and waiting clean shutdown + generator_task.cancel() + + try: + loop.run_until_complete(generator_task) + except (StopAsyncIteration, asyncio.CancelledError): + pass + + # Then reraise the exception + raise + async def __aenter__(self): assert isinstance( diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 2a9942ff..c2ec7335 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -1,5 +1,6 @@ import asyncio import json +from typing import List import pytest import websockets @@ -17,7 +18,13 @@ WITH_KEEPALIVE = False +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + async def server_countdown(ws, path): + logged_messages.clear() + global WITH_KEEPALIVE try: await WebSocketServer.send_connection_ack(ws) @@ -25,6 +32,8 @@ async def server_countdown(ws, path): await WebSocketServer.send_keepalive(ws) result = await ws.recv() + logged_messages.append(result) + json_result = json.loads(result) assert json_result["type"] == "start" payload = json_result["payload"] @@ -48,7 +57,12 @@ async def stopping_coro(): nonlocal counting_task while True: - result = await ws.recv() + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + json_result = json.loads(result) if json_result["type"] == "stop" and json_result["id"] == str(query_id): @@ -58,7 +72,10 @@ async def stopping_coro(): async def keepalive_coro(): while True: await asyncio.sleep(5 * MS) - await WebSocketServer.send_keepalive(ws) + try: + await WebSocketServer.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break stopping_task = asyncio.ensure_future(stopping_coro()) keepalive_task = asyncio.ensure_future(keepalive_coro()) @@ -351,3 +368,40 @@ def test_websocket_subscription_sync(server, subscription_str): count -= 1 assert count == -1 + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + assert logged_messages.pop() == '{"type": "connection_terminate"}' From f51dd703e120920d968319ea3d20d4a6921f30fa Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 5 Jul 2020 18:40:10 +0200 Subject: [PATCH 2/2] Ignore test on windows --- tests/test_websocket_subscription.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index c2ec7335..93cbde4b 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -1,5 +1,6 @@ import asyncio import json +import sys from typing import List import pytest @@ -370,9 +371,21 @@ def test_websocket_subscription_sync(server, subscription_str): assert count == -1 +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str): + """ Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}")