1
1
import asyncio
2
2
import json
3
+ from typing import List
3
4
4
5
import pytest
5
6
import websockets
17
18
WITH_KEEPALIVE = False
18
19
19
20
21
+ # List which can used to store received messages by the server
22
+ logged_messages : List [str ] = []
23
+
24
+
20
25
async def server_countdown (ws , path ):
26
+ logged_messages .clear ()
27
+
21
28
global WITH_KEEPALIVE
22
29
try :
23
30
await WebSocketServer .send_connection_ack (ws )
24
31
if WITH_KEEPALIVE :
25
32
await WebSocketServer .send_keepalive (ws )
26
33
27
34
result = await ws .recv ()
35
+ logged_messages .append (result )
36
+
28
37
json_result = json .loads (result )
29
38
assert json_result ["type" ] == "start"
30
39
payload = json_result ["payload" ]
@@ -48,7 +57,12 @@ async def stopping_coro():
48
57
nonlocal counting_task
49
58
while True :
50
59
51
- result = await ws .recv ()
60
+ try :
61
+ result = await ws .recv ()
62
+ logged_messages .append (result )
63
+ except websockets .exceptions .ConnectionClosed :
64
+ break
65
+
52
66
json_result = json .loads (result )
53
67
54
68
if json_result ["type" ] == "stop" and json_result ["id" ] == str (query_id ):
@@ -58,7 +72,10 @@ async def stopping_coro():
58
72
async def keepalive_coro ():
59
73
while True :
60
74
await asyncio .sleep (5 * MS )
61
- await WebSocketServer .send_keepalive (ws )
75
+ try :
76
+ await WebSocketServer .send_keepalive (ws )
77
+ except websockets .exceptions .ConnectionClosed :
78
+ break
62
79
63
80
stopping_task = asyncio .ensure_future (stopping_coro ())
64
81
keepalive_task = asyncio .ensure_future (keepalive_coro ())
@@ -351,3 +368,40 @@ def test_websocket_subscription_sync(server, subscription_str):
351
368
count -= 1
352
369
353
370
assert count == - 1
371
+
372
+
373
+ @pytest .mark .parametrize ("server" , [server_countdown ], indirect = True )
374
+ @pytest .mark .parametrize ("subscription_str" , [countdown_subscription_str ])
375
+ def test_websocket_subscription_sync_graceful_shutdown (server , subscription_str ):
376
+
377
+ url = f"ws://{ server .hostname } :{ server .port } /graphql"
378
+ print (f"url = { url } " )
379
+
380
+ sample_transport = WebsocketsTransport (url = url )
381
+
382
+ client = Client (transport = sample_transport )
383
+
384
+ count = 10
385
+ subscription = gql (subscription_str .format (count = count ))
386
+
387
+ with pytest .raises (KeyboardInterrupt ):
388
+ for result in client .subscribe (subscription ):
389
+
390
+ number = result ["number" ]
391
+ print (f"Number received: { number } " )
392
+
393
+ assert number == count
394
+
395
+ if count == 5 :
396
+
397
+ # Simulate a KeyboardInterrupt in the generator
398
+ asyncio .ensure_future (
399
+ client .session ._generator .athrow (KeyboardInterrupt )
400
+ )
401
+
402
+ count -= 1
403
+
404
+ assert count == 4
405
+
406
+ # Check that the server received a connection_terminate message last
407
+ assert logged_messages .pop () == '{"type": "connection_terminate"}'
0 commit comments