diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index 9a3caf6e..1b250d7a 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -14,3 +14,5 @@ gql.transport .. autoclass:: gql.transport.aiohttp.AIOHTTPTransport .. autoclass:: gql.transport.websockets.WebsocketsTransport + +.. autoclass:: gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 4b792232..91f2bf40 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -5,6 +5,8 @@ AIOHTTPTransport This transport uses the `aiohttp`_ library and allows you to send GraphQL queries using the HTTP protocol. +Reference: :py:class:`gql.transport.aiohttp.AIOHTTPTransport` + .. note:: GraphQL subscriptions are not supported on the HTTP transport. diff --git a/docs/transports/requests.rst b/docs/transports/requests.rst index f920f3e0..15eaedb5 100644 --- a/docs/transports/requests.rst +++ b/docs/transports/requests.rst @@ -6,6 +6,8 @@ RequestsHTTPTransport The RequestsHTTPTransport is a sync transport using the `requests`_ library and allows you to send GraphQL queries using the HTTP protocol. +Reference: :py:class:`gql.transport.requests.RequestsHTTPTransport` + .. literalinclude:: ../code_examples/requests_sync.py .. _requests: https://requests.readthedocs.io diff --git a/docs/transports/websockets.rst b/docs/transports/websockets.rst index a082d887..7c91efb6 100644 --- a/docs/transports/websockets.rst +++ b/docs/transports/websockets.rst @@ -3,10 +3,17 @@ WebsocketsTransport =================== -The websockets transport implements the `Apollo websockets transport protocol`_. +The websockets transport supports both: + + - the `Apollo websockets transport protocol`_. + - the `GraphQL-ws websockets transport protocol`_ + +It will detect the backend supported protocol from the response http headers returned. This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. +Reference: :py:class:`gql.transport.websockets.WebsocketsTransport` + .. literalinclude:: ../code_examples/websockets_async.py Websockets SSL @@ -14,11 +21,11 @@ Websockets SSL If you need to connect to an ssl encrypted endpoint: -* use _wss_ instead of _ws_ in the url of the transport +* use :code:`wss` instead of :code:`ws` in the url of the transport .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', headers={'Authorization': 'token'} ) @@ -34,7 +41,7 @@ If you have a self-signed ssl certificate, you need to provide an ssl_context wi localhost_pem = pathlib.Path(__file__).with_name("YOUR_SERVER_PUBLIC_CERTIFICATE.pem") ssl_context.load_verify_locations(localhost_pem) - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', ssl=ssl_context ) @@ -54,7 +61,7 @@ There are two ways to send authentication tokens with websockets depending on th .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', headers={'Authorization': 'token'} ) @@ -63,9 +70,53 @@ There are two ways to send authentication tokens with websockets depending on th .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', init_payload={'Authorization': 'token'} ) +Keep-Alives +----------- + +Apollo protocol +^^^^^^^^^^^^^^^ + +With the Apollo protocol, the backend can optionally send unidirectional keep-alive ("ka") messages +(only from the server to the client). + +It is possible to configure the transport to close if we don't receive a "ka" message +within a specified time using the :code:`keep_alive_timeout` parameter. + +Here is an example with 60 seconds:: + + transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + keep_alive_timeout=60, + ) + +One disadvantage of the Apollo protocol is that because the keep-alives are only sent from the server +to the client, it can be difficult to detect the loss of a connection quickly from the server side. + +GraphQL-ws protocol +^^^^^^^^^^^^^^^^^^^ + +With the GraphQL-ws protocol, it is possible to send bidirectional ping/pong messages. +Pings can be sent either from the client or the server and the other party should answer with a pong. + +As with the Apollo protocol, it is possible to configure the transport to close if we don't +receive any message from the backend within the specified time using the :code:`keep_alive_timeout` parameter. + +But there is also the possibility for the client to send pings at a regular interval and verify +that the backend sends a pong within a specified delay. +This can be done using the :code:`ping_interval` and :code:`pong_timeout` parameters. + +Here is an example with a ping sent every 60 seconds, expecting a pong within 10 seconds:: + + transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + ping_interval=60, + pong_timeout=10, + ) + .. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 4ec8ce89..06552d2f 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -86,6 +86,11 @@ class WebsocketsTransport(AsyncTransport): on a websocket connection. """ + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") + GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") + def __init__( self, url: str, @@ -96,6 +101,9 @@ def __init__( close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -112,8 +120,18 @@ def __init__( from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True :param connect_args: Other parameters forwarded to websockets.connect """ + self.url: str = url self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers @@ -123,6 +141,15 @@ def __init__( self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout self.connect_args = connect_args @@ -132,6 +159,7 @@ def __init__( self.receive_data_task: Optional[asyncio.Future] = None self.check_keep_alive_task: Optional[asyncio.Future] = None + self.send_ping_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -152,10 +180,28 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + with the graphql-ws protocol. + Possible keys are: 'ping', 'pong', 'connection_ack'""" + self._connecting: bool = False self.close_exception: Optional[Exception] = None + self.supported_subprotocols = [ + self.GRAPHQLWS_SUBPROTOCOL, + self.APOLLO_SUBPROTOCOL, + ] + async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message""" @@ -223,6 +269,28 @@ async def _send_init_message_and_wait_ack(self) -> None: # Wait for the connection_ack message or raise a TimeoutError await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol + """ + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol + """ + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + async def _send_stop_message(self, query_id: int) -> None: """Send stop message to the provided websocket connection and query_id. @@ -233,6 +301,32 @@ async def _send_stop_message(self, query_id: int) -> None: await self._send(stop_message) + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int) -> None: + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. + + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + async def _send_connection_terminate_message(self) -> None: """Send a connection_terminate message to the provided websocket connection. @@ -265,18 +359,102 @@ async def _send_query( if operation_name: payload["operationName"] = operation_name + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + query_str = json.dumps( - {"id": str(query_id), "type": "start", "payload": payload} + {"id": str(query_id), "type": query_type, "payload": payload} ) await self._send(query_str) return query_id - def _parse_answer( + def _parse_answer_graphqlws( self, answer: str ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + json_answer = json.loads(answer) + + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "next": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. Returns a list consisting of: - the answer_type (between: @@ -342,6 +520,17 @@ def _parse_answer( return answer_type, answer_id, execution_result + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(answer) + + return self._parse_answer_apollo(answer) + async def _check_ws_liveness(self) -> None: """Coroutine which will periodically check the liveness of the connection through keep-alive messages @@ -376,6 +565,39 @@ async def _check_ws_liveness(self) -> None: # The client is probably closing, handle it properly pass + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + async def _receive_data_loop(self) -> None: try: while True: @@ -428,6 +650,7 @@ async def _handle_answer( answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: + try: # Put the answer in the queue if answer_id is not None: @@ -436,6 +659,15 @@ async def _handle_answer( # Do nothing if no one is listening to this query_id. pass + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + async def subscribe( self, document: DocumentNode, @@ -486,7 +718,7 @@ async def subscribe( except (asyncio.CancelledError, GeneratorExit) as e: log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: - await self._send_stop_message(query_id) + await self._stop_listener(query_id) listener.send_stop = False finally: @@ -540,8 +772,6 @@ async def connect(self) -> None: Should be cleaned with a call to the close coroutine """ - GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws") - log.debug("connect: starting") if self.websocket is None and not self._connecting: @@ -562,7 +792,7 @@ async def connect(self) -> None: connect_args: Dict[str, Any] = { "ssl": ssl, "extra_headers": self.headers, - "subprotocols": [GRAPHQLWS_SUBPROTOCOL], + "subprotocols": self.supported_subprotocols, } # Adding custom parameters passed from init @@ -579,6 +809,19 @@ async def connect(self) -> None: finally: self._connecting = False + self.websocket = cast(WebSocketClientProtocol, self.websocket) + + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket.response_headers + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + self.next_query_id = 1 self.close_exception = None self._wait_closed.clear() @@ -601,6 +844,14 @@ async def connect(self) -> None: self._check_ws_liveness() ) + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) @@ -633,7 +884,7 @@ async def _clean_close(self, e: Exception) -> None: for query_id, listener in self.listeners.items(): if listener.send_stop: - await self._send_stop_message(query_id) + await self._stop_listener(query_id) listener.send_stop = False # Wait that there is no more listeners (we received 'complete' for all queries) @@ -642,8 +893,9 @@ async def _clean_close(self, e: Exception) -> None: except asyncio.TimeoutError: # pragma: no cover log.debug("Timer close_timeout fired") - # Finally send the 'connection_terminate' message - await self._send_connection_terminate_message() + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + # Finally send the 'connection_terminate' message + await self._send_connection_terminate_message() async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: """Coroutine which will: @@ -669,6 +921,12 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: with suppress(asyncio.CancelledError): await self.check_keep_alive_task + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task + # Saving exception to raise it later if trying to use the transport # after it has already closed. self.close_exception = e @@ -702,6 +960,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None self.close_task = None self.check_keep_alive_task = None + self.send_ping_task = None self._wait_closed.set() diff --git a/tests/conftest.py b/tests/conftest.py index df69c121..004fa9df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,13 +128,14 @@ class WebSocketServer: def __init__(self, with_ssl: bool = False): self.with_ssl = with_ssl - async def start(self, handler): + async def start(self, handler, extra_serve_args=None): import websockets.server print("Starting server") - extra_serve_args = {} + if extra_serve_args is None: + extra_serve_args = {} if self.with_ssl: # This is a copy of certificate from websockets tests folder @@ -192,7 +193,21 @@ async def send_keepalive(ws): await ws.send('{"type":"ka"}') @staticmethod - async def send_connection_ack(ws): + async def send_ping(ws, payload=None): + if payload is None: + await ws.send('{"type":"ping"}') + else: + await ws.send(json.dumps({"type": "ping", "payload": payload})) + + @staticmethod + async def send_pong(ws, payload=None): + if payload is None: + await ws.send('{"type":"pong"}') + else: + await ws.send(json.dumps({"type": "pong", "payload": payload})) + + @staticmethod + async def send_connection_ack(ws, payload=None): # Line return for easy debugging print("") @@ -203,7 +218,10 @@ async def send_connection_ack(ws): assert json_result["type"] == "connection_init" # Send ack - await ws.send('{"type":"connection_ack"}') + if payload is None: + await ws.send('{"type":"connection_ack"}') + else: + await ws.send(json.dumps({"type": "connection_ack", "payload": payload})) @staticmethod async def wait_connection_terminate(ws): @@ -352,6 +370,54 @@ async def server(request): await test_server.stop() +@pytest.fixture +async def graphqlws_server(request): + """Fixture used to start a dummy server with the graphql-ws protocol. + + Similar to the server fixture above but will return "graphql-transport-ws" + as the server subprotocol. + + It can take as argument either a handler function for the websocket server for + complete control OR an array of answers to be sent by the default server handler. + """ + + subprotocol = "graphql-transport-ws" + + from websockets.server import WebSocketServerProtocol + + class CustomSubprotocol(WebSocketServerProtocol): + def select_subprotocol(self, client_subprotocols, server_subprotocols): + print(f"Client subprotocols: {client_subprotocols!r}") + print(f"Server subprotocols: {server_subprotocols!r}") + + return subprotocol + + def process_subprotocol(self, headers, available_subprotocols): + # Overwriting available subprotocols + available_subprotocols = [subprotocol] + + print(f"headers: {headers!r}") + # print (f"Available subprotocols: {available_subprotocols!r}") + + return super().process_subprotocol(headers, available_subprotocols) + + server_handler = get_server_handler(request) + + try: + test_server = WebSocketServer() + + # Starting the server with the fixture param as the handler function + await test_server.start( + server_handler, extra_serve_args={"create_protocol": CustomSubprotocol} + ) + + yield test_server + except Exception as e: + print("Exception received in server fixture:", e) + finally: + await test_server.stop() + + @pytest.fixture async def client_and_server(server): """Helper fixture to start a server and a client connected to its port.""" @@ -369,6 +435,24 @@ async def client_and_server(server): yield session, server +@pytest.fixture +async def client_and_graphqlws_server(graphqlws_server): + """Helper fixture to start a server with the graphql-ws prototocol + and a client connected to its port.""" + + from gql.transport.websockets import WebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + sample_transport = WebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, graphqlws_server + + @pytest.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py new file mode 100644 index 00000000..8a2e7495 --- /dev/null +++ b/tests/test_graphqlws_exceptions.py @@ -0,0 +1,289 @@ +import asyncio +import json +import types +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"next","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_invalid_query( + event_loop, client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_invalid_subscription], indirect=True +) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_graphqlws_invalid_subscription( + event_loop, client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_server_does_not_send_ack( + event_loop, graphqlws_server, query_str +): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_sending_invalid_payload( + event_loop, client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, document, variable_values=None, operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"} + ) + + await self._send(query_str) + return query_id + + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["message"] == "Must provide document" + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "next"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "next", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "next", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_graphqlws_transport_protocol_errors( + event_loop, client_and_graphqlws_server +): + + session, server = client_and_graphqlws_server + + query = gql("query { hello }") + + with pytest.raises(TransportProtocolError): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) +async def test_graphqlws_server_does_not_ack(event_loop, graphqlws_server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) +async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): + import websockets + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) +async def test_graphqlws_server_closing_after_ack( + event_loop, client_and_graphqlws_server +): + + import websockets + + session, server = client_and_graphqlws_server + + query = gql("query { hello }") + + with pytest.raises(websockets.exceptions.ConnectionClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py new file mode 100644 index 00000000..8f38d101 --- /dev/null +++ b/tests/test_graphqlws_subscription.py @@ -0,0 +1,745 @@ +import asyncio +import json +import sys +from typing import List + +import pytest +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +countdown_server_answer = ( + '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +COUNTING_DELAY = 2 * MS +PING_SENDING_DELAY = 5 * MS +PONG_TIMEOUT = 2 * MS + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def server_countdown_factory(keepalive=False, answer_pings=True): + async def server_countdown_template(ws, path): + import websockets + + logged_messages.clear() + + try: + await WebSocketServerHelper.send_connection_ack( + ws, payload="dummy_connection_ack_payload" + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + pong_received: asyncio.Event = asyncio.Event() + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(COUNTING_DELAY) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def keepalive_coro(): + while True: + await asyncio.sleep(PING_SENDING_DELAY) + try: + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for(pong_received.wait(), PONG_TIMEOUT) + except asyncio.TimeoutError: + print("\nNo pong received in time!\n") + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: + break + + if keepalive: + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong(ws, payload=payload) + + elif answer_type == "pong": + pong_received.set() + + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print("Now receiving task is cancelled") + + if keepalive: + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print("Now keepalive task is cancelled") + + await WebSocketServerHelper.send_complete(ws, query_id) + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\nAssertion failed: {e!s}\n") + finally: + await ws.wait_closed() + + return server_countdown_template + + +async def server_countdown(ws, path): + + server = server_countdown_factory() + await server(ws, path) + + +async def server_countdown_keepalive(ws, path): + + server = server_countdown_factory(keepalive=True) + await server(ws, path) + + +async def server_countdown_dont_answer_pings(ws, path): + + server = server_countdown_factory(answer_pings=False) + await server(ws, path) + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_break( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_task_cancel( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_close_transport( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(COUNTING_DELAY) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_server_connection_closed( + event_loop, client_and_graphqlws_server, subscription_str +): + import websockets + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(websockets.exceptions.ConnectionClosedOK): + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_operation_name( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + assert session.transport.payloads["ping"] == "dummy_ping_payload" + assert ( + session.transport.payloads["connection_ack"] == "dummy_connection_ack_payload" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, keep_alive_timeout=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, keep_alive_timeout=(COUNTING_DELAY / 2)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_ping_interval_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport( + url=url, ping_interval=(5 * COUNTING_DELAY), pong_timeout=(4 * COUNTING_DELAY), + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_ping_interval_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, ping_interval=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No pong received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_manual_pings_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + payload = {"count_received": count} + + await transport.send_ping(payload=payload) + + await transport.pong_received.wait() + transport.pong_received.clear() + + assert transport.payloads["pong"] == payload + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_manual_pong_answers_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, answer_pings=False) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + async def answer_ping_coro(): + while True: + await transport.ping_received.wait() + transport.ping_received.clear() + await transport.send_pong(payload={"some": "data"}) + + answer_ping_task = asyncio.ensure_future(answer_ping_coro()) + + try: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + finally: + answer_ping_task.cancel() + + assert count == -1 + + +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_graphqlws_subscription_sync(graphqlws_server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_graphqlws_subscription_sync_graceful_shutdown( + graphqlws_server, subscription_str +): + """ Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + # assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_running_in_thread( + event_loop, graphqlws_server, subscription_str, run_sync_test +): + from gql.transport.websockets import WebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, graphqlws_server, test_code) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 3c6ec2b2..6367945d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -235,6 +235,8 @@ async def test_phoenix_channel_subscription_no_break( from gql.transport.phoenix_channel_websockets import log as phoenix_logger from gql.transport.websockets import log as websockets_logger + from .conftest import MS + websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -244,7 +246,7 @@ async def test_phoenix_channel_subscription_no_break( async def testing_stopping_without_break(): sample_transport = PhoenixChannelWebsocketsTransport( - channel_name=test_channel, url=url, close_timeout=5 + channel_name=test_channel, url=url, close_timeout=(5000 * MS) ) count = 10 @@ -256,7 +258,8 @@ async def testing_stopping_without_break(): print(f"Number received: {number}") # Simulate a slow consumer - await asyncio.sleep(0.1) + if number == 10: + await asyncio.sleep(50 * MS) if number == 9: # When we consume the number 9 here in the async generator, @@ -274,7 +277,7 @@ async def testing_stopping_without_break(): assert count == -1 try: - await asyncio.wait_for(testing_stopping_without_break(), timeout=5) + await asyncio.wait_for(testing_stopping_without_break(), timeout=(5000 * MS)) except asyncio.TimeoutError: assert False, "The async generator did not stop" diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 7d87ee81..d5167720 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -391,7 +391,7 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS)) + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(10 * MS)) client = Client(transport=sample_transport) diff --git a/tox.ini b/tox.ini index 19231ed7..414f083b 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,7 @@ setenv = PYTHONPATH = {toxinidir} MULTIDICT_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/multidict YARL_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/yarl + GQL_TESTS_TIMEOUT_FACTOR = 10 install_command = python -m pip install --ignore-installed {opts} {packages} whitelist_externals = python