Skip to content

Commit 905b724

Browse files
authored
Fix sync subscribe graceful shutdown (#395)
1 parent 5e37e6a commit 905b724

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

gql/client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -593,15 +593,14 @@ def subscribe(
593593
except StopAsyncIteration:
594594
pass
595595

596-
except (KeyboardInterrupt, Exception):
596+
except (KeyboardInterrupt, Exception, GeneratorExit):
597+
598+
# Graceful shutdown
599+
asyncio.ensure_future(async_generator.aclose(), loop=loop)
597600

598-
# Graceful shutdown by cancelling the task and waiting clean shutdown
599601
generator_task.cancel()
600602

601-
try:
602-
loop.run_until_complete(generator_task)
603-
except (StopAsyncIteration, asyncio.CancelledError):
604-
pass
603+
loop.run_until_complete(loop.shutdown_asyncgens())
605604

606605
# Then reraise the exception
607606
raise

tests/test_websocket_subscription.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,66 @@ def test_websocket_subscription_sync(server, subscription_str):
494494
assert count == -1
495495

496496

497+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
498+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
499+
def test_websocket_subscription_sync_user_exception(server, subscription_str):
500+
from gql.transport.websockets import WebsocketsTransport
501+
502+
url = f"ws://{server.hostname}:{server.port}/graphql"
503+
print(f"url = {url}")
504+
505+
sample_transport = WebsocketsTransport(url=url)
506+
507+
client = Client(transport=sample_transport)
508+
509+
count = 10
510+
subscription = gql(subscription_str.format(count=count))
511+
512+
with pytest.raises(Exception) as exc_info:
513+
for result in client.subscribe(subscription):
514+
515+
number = result["number"]
516+
print(f"Number received: {number}")
517+
518+
assert number == count
519+
count -= 1
520+
521+
if count == 5:
522+
raise Exception("This is an user exception")
523+
524+
assert count == 5
525+
assert "This is an user exception" in str(exc_info.value)
526+
527+
528+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
529+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
530+
def test_websocket_subscription_sync_break(server, subscription_str):
531+
from gql.transport.websockets import WebsocketsTransport
532+
533+
url = f"ws://{server.hostname}:{server.port}/graphql"
534+
print(f"url = {url}")
535+
536+
sample_transport = WebsocketsTransport(url=url)
537+
538+
client = Client(transport=sample_transport)
539+
540+
count = 10
541+
subscription = gql(subscription_str.format(count=count))
542+
543+
for result in client.subscribe(subscription):
544+
545+
number = result["number"]
546+
print(f"Number received: {number}")
547+
548+
assert number == count
549+
count -= 1
550+
551+
if count == 5:
552+
break
553+
554+
assert count == 5
555+
556+
497557
@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows")
498558
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
499559
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])

0 commit comments

Comments
 (0)