1
1
import asyncio
2
2
import json
3
+ import sys
4
+ from typing import List
3
5
4
6
import pytest
5
7
import websockets
17
19
WITH_KEEPALIVE = False
18
20
19
21
22
+ # List which can used to store received messages by the server
23
+ logged_messages : List [str ] = []
24
+
25
+
20
26
async def server_countdown (ws , path ):
27
+ logged_messages .clear ()
28
+
21
29
global WITH_KEEPALIVE
22
30
try :
23
31
await WebSocketServer .send_connection_ack (ws )
24
32
if WITH_KEEPALIVE :
25
33
await WebSocketServer .send_keepalive (ws )
26
34
27
35
result = await ws .recv ()
36
+ logged_messages .append (result )
37
+
28
38
json_result = json .loads (result )
29
39
assert json_result ["type" ] == "start"
30
40
payload = json_result ["payload" ]
@@ -48,7 +58,12 @@ async def stopping_coro():
48
58
nonlocal counting_task
49
59
while True :
50
60
51
- result = await ws .recv ()
61
+ try :
62
+ result = await ws .recv ()
63
+ logged_messages .append (result )
64
+ except websockets .exceptions .ConnectionClosed :
65
+ break
66
+
52
67
json_result = json .loads (result )
53
68
54
69
if json_result ["type" ] == "stop" and json_result ["id" ] == str (query_id ):
@@ -58,7 +73,10 @@ async def stopping_coro():
58
73
async def keepalive_coro ():
59
74
while True :
60
75
await asyncio .sleep (5 * MS )
61
- await WebSocketServer .send_keepalive (ws )
76
+ try :
77
+ await WebSocketServer .send_keepalive (ws )
78
+ except websockets .exceptions .ConnectionClosed :
79
+ break
62
80
63
81
stopping_task = asyncio .ensure_future (stopping_coro ())
64
82
keepalive_task = asyncio .ensure_future (keepalive_coro ())
@@ -351,3 +369,52 @@ def test_websocket_subscription_sync(server, subscription_str):
351
369
count -= 1
352
370
353
371
assert count == - 1
372
+
373
+
374
+ @pytest .mark .skipif (sys .platform .startswith ("win" ), reason = "test failing on windows" )
375
+ @pytest .mark .parametrize ("server" , [server_countdown ], indirect = True )
376
+ @pytest .mark .parametrize ("subscription_str" , [countdown_subscription_str ])
377
+ def test_websocket_subscription_sync_graceful_shutdown (server , subscription_str ):
378
+ """ Note: this test will simulate a control-C happening while a sync subscription
379
+ is in progress. To do that we will throw a KeyboardInterrupt exception inside
380
+ the subscription async generator.
381
+
382
+ The code should then do a clean close:
383
+ - send stop messages for each active query
384
+ - send a connection_terminate message
385
+ Then the KeyboardInterrupt will be reraise (to warn potential user code)
386
+
387
+ This test does not work on Windows but the behaviour with Windows is correct.
388
+ """
389
+
390
+ url = f"ws://{ server .hostname } :{ server .port } /graphql"
391
+ print (f"url = { url } " )
392
+
393
+ sample_transport = WebsocketsTransport (url = url )
394
+
395
+ client = Client (transport = sample_transport )
396
+
397
+ count = 10
398
+ subscription = gql (subscription_str .format (count = count ))
399
+
400
+ with pytest .raises (KeyboardInterrupt ):
401
+ for result in client .subscribe (subscription ):
402
+
403
+ number = result ["number" ]
404
+ print (f"Number received: { number } " )
405
+
406
+ assert number == count
407
+
408
+ if count == 5 :
409
+
410
+ # Simulate a KeyboardInterrupt in the generator
411
+ asyncio .ensure_future (
412
+ client .session ._generator .athrow (KeyboardInterrupt )
413
+ )
414
+
415
+ count -= 1
416
+
417
+ assert count == 4
418
+
419
+ # Check that the server received a connection_terminate message last
420
+ assert logged_messages .pop () == '{"type": "connection_terminate"}'
0 commit comments