diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 27e58f2a..56d35f8b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,5 +1,6 @@ import asyncio import json +import logging from typing import Any, Dict, Optional, Tuple from graphql import DocumentNode, ExecutionResult, print_ast @@ -12,6 +13,16 @@ ) from .websockets import WebsocketsTransport +log = logging.getLogger(__name__) + + +class Subscription: + """Records listener_id and unsubscribe query_id for a subscription.""" + + def __init__(self, query_id: int) -> None: + self.listener_id: int = query_id + self.unsubscribe_id: Optional[int] = None + class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport @@ -24,17 +35,23 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """ def __init__( - self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs + self, + channel_name: str = "__absinthe__:control", + heartbeat_interval: float = 30, + *args, + **kwargs, ) -> None: """Initialize the transport with the given parameters. - :param channel_name: Channel on the server this transport will join + :param channel_name: Channel on the server this transport will join. + The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client """ - self.channel_name = channel_name - self.heartbeat_interval = heartbeat_interval - self.subscription_ids_to_query_ids: Dict[str, int] = {} + self.channel_name: str = channel_name + self.heartbeat_interval: float = heartbeat_interval + self.heartbeat_task: Optional[asyncio.Future] = None + self.subscriptions: Dict[str, Subscription] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) async def _send_init_message_and_wait_ack(self) -> None: @@ -90,14 +107,32 @@ async def heartbeat_coro(): self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) async def _send_stop_message(self, query_id: int) -> None: - try: - await self.listeners[query_id].put(("complete", None)) - except KeyError: # pragma: no cover - pass + """Send an 'unsubscribe' message to the Phoenix Channel referencing + the listener's query_id, saving the query_id of the message. - async def _send_connection_terminate_message(self) -> None: - """Send a phx_leave message to disconnect from the provided channel. + The server should afterwards return a 'phx_reply' message with + the same query_id and subscription_id of the 'unsubscribe' request. """ + subscription_id = self._find_existing_subscription(query_id) + + unsubscribe_query_id = self.next_query_id + self.next_query_id += 1 + + # Save the ref so it can be matched in the reply + self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id + unsubscribe_message = json.dumps( + { + "topic": self.channel_name, + "event": "unsubscribe", + "payload": {"subscriptionId": subscription_id}, + "ref": unsubscribe_query_id, + } + ) + + await self._send(unsubscribe_message) + + async def _send_connection_terminate_message(self) -> None: + """Send a phx_leave message to disconnect from the provided channel.""" query_id = self.next_query_id self.next_query_id += 1 @@ -152,7 +187,7 @@ def _parse_answer( Returns a list consisting of: - the answer_type (between: - 'heartbeat', 'data', 'reply', 'error', 'close') + 'data', 'reply', 'complete', 'close') - the answer id (Integer) if received or None - an execution Result if the answer_type is 'data' or None """ @@ -161,56 +196,129 @@ def _parse_answer( answer_id: Optional[int] = None answer_type: str = "" execution_result: Optional[ExecutionResult] = None + subscription_id: Optional[str] = None + + def _get_value(d: Any, key: str, label: str) -> Any: + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + return d.get(key) + + def _required_value(d: Any, key: str, label: str) -> Any: + value = _get_value(d, key, label) + if value is None: + raise ValueError(f"null {key} in {label}") + + return value + + def _required_subscription_id( + d: Any, label: str, must_exist: bool = False, must_not_exist=False + ) -> str: + subscription_id = str(_required_value(d, "subscriptionId", label)) + if must_exist and (subscription_id not in self.subscriptions): + raise ValueError("unregistered subscriptionId") + if must_not_exist and (subscription_id in self.subscriptions): + raise ValueError("previously registered subscriptionId") + + return subscription_id + + def _validate_data_response(d: Any, label: str) -> dict: + """Make sure query, mutation or subscription answer conforms. + The GraphQL spec says only three keys are permitted. + """ + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + keys = set(d.keys()) + invalid = keys - {"data", "errors", "extensions"} + if len(invalid) > 0: + raise ValueError( + f"{label} contains invalid items: " + ", ".join(invalid) + ) + return d try: json_answer = json.loads(answer) - event = str(json_answer.get("event")) + event = str(_required_value(json_answer, "event", "answer")) if event == "subscription:data": - payload = json_answer.get("payload") + payload = _required_value(json_answer, "payload", "answer") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - subscription_id = str(payload.get("subscriptionId")) - try: - answer_id = self.subscription_ids_to_query_ids[subscription_id] - except KeyError: - raise ValueError( - f"subscription '{subscription_id}' has not been registerd" - ) - - result = payload.get("result") + subscription_id = _required_subscription_id( + payload, "payload", must_exist=True + ) - if not isinstance(result, dict): - raise ValueError("result is not a dict") + result = _validate_data_response(payload.get("result"), "result") answer_type = "data" + subscription = self.subscriptions[subscription_id] + answer_id = subscription.listener_id + execution_result = ExecutionResult( - errors=payload.get("errors"), data=result.get("data"), - extensions=payload.get("extensions"), + errors=result.get("errors"), + extensions=result.get("extensions"), ) elif event == "phx_reply": - answer_id = int(json_answer.get("ref")) - payload = json_answer.get("payload") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") + # Will generate a ValueError if 'ref' is not there + # or if it is not an integer + answer_id = int(_required_value(json_answer, "ref", "answer")) - status = str(payload.get("status")) + payload = _required_value(json_answer, "payload", "answer") - if status == "ok": + status = _get_value(payload, "status", "payload") + if status == "ok": answer_type = "reply" - response = payload.get("response") - if isinstance(response, dict) and "subscriptionId" in response: - subscription_id = str(response.get("subscriptionId")) - self.subscription_ids_to_query_ids[subscription_id] = answer_id + if answer_id in self.listeners: + response = _required_value(payload, "response", "payload") + + if isinstance(response, dict) and "subscriptionId" in response: + + # Subscription answer + subscription_id = _required_subscription_id( + response, "response", must_not_exist=True + ) + + self.subscriptions[subscription_id] = Subscription( + answer_id + ) + + else: + # Query or mutation answer + # GraphQL spec says only three keys are permitted + response = _validate_data_response(response, "response") + + answer_type = "data" + + execution_result = ExecutionResult( + data=response.get("data"), + errors=response.get("errors"), + extensions=response.get("extensions"), + ) + else: + ( + registered_subscription_id, + listener_id, + ) = self._find_subscription(answer_id) + if registered_subscription_id is not None: + # Unsubscription answer + response = _required_value(payload, "response", "payload") + subscription_id = _required_subscription_id( + response, "response" + ) + + if subscription_id != registered_subscription_id: + raise ValueError("subscription id does not match") + + answer_type = "complete" + + answer_id = listener_id elif status == "error": response = payload.get("response") @@ -224,21 +332,28 @@ def _parse_answer( raise TransportQueryError( str(response.get("reason")), query_id=answer_id ) - raise ValueError("reply error") + raise TransportQueryError("reply error", query_id=answer_id) elif status == "timeout": raise TransportQueryError("reply timeout", query_id=answer_id) + else: + # missing or unrecognized status, just continue + pass elif event == "phx_error": + # Sent if the channel has crashed + # answer_id will be the "join_ref" for the channel + # answer_id = int(json_answer.get("ref")) raise TransportServerError("Server error") elif event == "phx_close": answer_type = "close" else: - raise ValueError + raise ValueError("unrecognized event") except ValueError as e: + log.error(f"Error parsing answer '{answer}': {e!r}") raise TransportProtocolError( - "Server did not return a GraphQL result" + f"Server did not return a GraphQL result: {e!s}" ) from e return answer_type, answer_id, execution_result @@ -254,6 +369,38 @@ async def _handle_answer( else: await super()._handle_answer(answer_type, answer_id, execution_result) + def _remove_listener(self, query_id: int) -> None: + """If the listener was a subscription, remove that information.""" + try: + subscription_id = self._find_existing_subscription(query_id) + del self.subscriptions[subscription_id] + except Exception: + pass + super()._remove_listener(query_id) + + def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. + """ + for subscription_id, subscription in self.subscriptions.items(): + if query_id == subscription.listener_id: + return subscription_id, query_id + if query_id == subscription.unsubscribe_id: + return subscription_id, subscription.listener_id + return None, query_id + + def _find_existing_subscription(self, query_id: int) -> str: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. + """ + subscription_id, _listener_id = self._find_subscription(query_id) + + if subscription_id is None: + raise TransportProtocolError( + f"No subscription registered for listener {query_id}" + ) + return subscription_id + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: if self.heartbeat_task is not None: self.heartbeat_task.cancel() diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 7e26f31c..50eeb6b0 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -175,8 +175,9 @@ async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - # We should always have an active websocket connection here - assert self.websocket is not None + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") # Wait for the next websocket frame. Can raise ConnectionClosed data: Data = await self.websocket.recv() @@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None: except (ConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break + except TransportClosed: + break # Parse the answer try: @@ -483,15 +486,14 @@ async def subscribe( break except (asyncio.CancelledError, GeneratorExit) as e: - log.debug("Exception in subscribe: " + repr(e)) + log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: await self._send_stop_message(query_id) listener.send_stop = False finally: - del self.listeners[query_id] - if len(self.listeners) == 0: - self._no_more_listeners.set() + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) async def execute( self, @@ -609,6 +611,19 @@ async def connect(self) -> None: log.debug("connect: done") + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: @@ -627,7 +642,7 @@ async def _clean_close(self, e: Exception) -> None: try: await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) except asyncio.TimeoutError: # pragma: no cover - pass + log.debug("Timer close_timeout fired") # Finally send the 'connection_terminate' message await self._send_connection_terminate_message() diff --git a/tests/conftest.py b/tests/conftest.py index 62f107ac..df69c121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,12 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.legacy.server", "gql.transport.websockets", "gql.dsl"]: +for name in [ + "websockets.legacy.server", + "gql.transport.websockets", + "gql.transport.phoenix_channel_websockets", + "gql.dsl", +]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) @@ -170,7 +175,7 @@ async def stop(self): self.server.close() try: - await asyncio.wait_for(self.server.wait_closed(), timeout=1) + await asyncio.wait_for(self.server.wait_closed(), timeout=5) except asyncio.TimeoutError: # pragma: no cover assert False, "Server failed to stop" diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 6f066325..1711d25a 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from gql import Client, gql @@ -7,9 +9,22 @@ TransportServerError, ) +from .conftest import MS + # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets + +def ensure_list(s): + return ( + s + if s is None or isinstance(s, list) + else list(s) + if isinstance(s, tuple) + else [s] + ) + + query1_str = """ query getContinents { continents { @@ -19,21 +34,48 @@ } """ -default_subscription_server_answer = ( +default_query_server_answer = ( '{"event":"phx_reply",' '"payload":' '{"response":' - '{"subscriptionId":"test_subscription"},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' '"status":"ok"},' '"ref":2,' '"topic":"test_topic"}' ) + +# other protocol exceptions + +reply_ref_null_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":null,' '"topic":"test_topic"}', +) + +reply_ref_zero_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":0,' '"topic":"test_topic"}', +) + + +# "status":"error" responses + +generic_error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + error_with_reason_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"reason":"internal error"},' + '{"response":{"reason":"internal error"},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -42,8 +84,7 @@ multiple_errors_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"errors": ["error 1", "error 2"]},' + '{"response":{"errors": ["error 1", "error 2"]},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -57,31 +98,95 @@ '"topic":"test_topic"}' ) +invalid_payload_data_answer = ( + '{"event":"phx_reply",' '"payload":"INVALID",' '"ref":2,' '"topic":"test_topic"}' +) + +# "status":"ok" exceptions -def server( - query_server_answer, subscription_server_answer=default_subscription_server_answer, -): +invalid_response_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":"INVALID",' + '"status":"ok"}' + '"ref":2,' + '"topic":"test_topic"}' +) + +invalid_response_keys_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":' + '{"data":{"continents":null},"invalid":null}",' + '"status":"ok"}' + '"ref":2,' + '"topic":"test_topic"}' +) + +invalid_event_server_answer = '{"event":"unknown"}' + + +def query_server(server_answers=default_query_server_answer): from .conftest import PhoenixChannelServerHelper async def phoenix_server(ws, path): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() - await ws.send(subscription_server_answer) - if query_server_answer is not None: - await ws.send(query_server_answer) + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() return phoenix_server +async def no_connection_ack_phoenix_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await ws.recv() + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(error_with_reason_server_answer), - server(multiple_errors_server_answer), - server(timeout_server_answer), + query_server(reply_ref_null_answer), + query_server(reply_ref_zero_answer), + query_server(invalid_payload_data_answer), + query_server(invalid_response_server_answer), + query_server(invalid_response_keys_server_answer), + no_connection_ack_phoenix_server, + query_server(invalid_event_server_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_protocol_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + query_server(generic_error_server_answer), + query_server(error_with_reason_server_answer), + query_server(multiple_errors_server_answer), + query_server(timeout_server_answer), ], indirect=True, ) @@ -104,71 +209,207 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): await session.execute(query) -invalid_subscription_id_server_answer = ( +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + +default_subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +ref_is_not_an_integer_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":"not_an_integer",' + '"topic":"test_topic"}' +) + +missing_ref_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"topic":"test_topic"}' +) + +missing_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +null_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":null},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +default_subscription_data_answer = ( '{"event":"subscription:data","payload":' - '{"subscriptionId":"INVALID","result":' + '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +default_subscription_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -invalid_payload_server_answer = ( +missing_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +null_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":null,"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"INVALID","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_payload_data_answer = ( '{"event":"subscription:data",' '"payload":"INVALID",' - '"ref":3,' - '"topic":"test_topic"}' + '"ref":null,' + '"topic":"test_subscription"}' ) -invalid_result_server_answer = ( +invalid_result_data_answer = ( '{"event":"subscription:data","payload":' - '{"subscriptionId":"test_subscription","result": "INVALID"},' - '"ref":3,' - '"topic":"test_topic"}' + '{"subscriptionId":"test_subscription","result":"INVALID"},' + '"ref":null,' + '"topic":"test_subscription"}' ) -generic_error_server_answer = ( +invalid_result_keys_data_answer = ( + '{"event":"subscription:data",' + '"payload":{"subscriptionId":"test_subscription",' + '"result":{"data":{"continents":null},"invalid":null}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_ref_answer = ( '{"event":"phx_reply",' - '"payload":' - '{"status":"error"},' - '"ref":2,' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":99,' '"topic":"test_topic"}' ) -protocol_server_answer = '{"event":"unknown"}' - -invalid_payload_subscription_server_answer = ( - '{"event":"phx_reply", "payload":"INVALID", "ref":2, "topic":"test_topic"}' +mismatched_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"no_such_subscription"},' + '"status":"ok"},' + '"ref":3,' + '"topic":"test_topic"}' ) -async def no_connection_ack_phoenix_server(ws, path): +def subscription_server( + server_answers=default_subscription_server_answer, + data_answers=default_subscription_data_answer, + unsubscribe_answers=default_subscription_unsubscribe_answer, +): from .conftest import PhoenixChannelServerHelper + import json - await ws.recv() - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + if server_answers is not None: + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) + if data_answers is not None: + for data_answer in ensure_list(data_answers): + await ws.send(data_answer) + if unsubscribe_answers is not None: + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "unsubscribe" + for unsubscribe_answer in ensure_list(unsubscribe_answers): + await ws.send(unsubscribe_answer) + else: + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + return phoenix_server @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(invalid_subscription_id_server_answer), - server(invalid_result_server_answer), - server(generic_error_server_answer), - no_connection_ack_phoenix_server, - server(protocol_server_answer), - server(invalid_payload_server_answer), - server(None, invalid_payload_subscription_server_answer), + subscription_server(invalid_subscription_ref_answer), + subscription_server(missing_subscription_id_server_answer), + subscription_server(null_subscription_id_server_answer), + subscription_server( + [default_subscription_server_answer, default_subscription_server_answer] + ), + subscription_server(data_answers=missing_subscription_id_data_answer), + subscription_server(data_answers=null_subscription_id_data_answer), + subscription_server(data_answers=invalid_subscription_id_data_answer), + subscription_server(data_answers=ref_is_not_an_integer_server_answer), + subscription_server(data_answers=missing_ref_server_answer), + subscription_server(data_answers=invalid_payload_data_answer), + subscription_server(data_answers=invalid_result_data_answer), + subscription_server(data_answers=invalid_result_keys_data_answer), ], indirect=True, ) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_protocol_error(event_loop, server, query_str): +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription_protocol_error( + event_loop, server, query_str +): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -183,17 +424,17 @@ async def test_phoenix_channel_protocol_error(event_loop, server, query_str): query = gql(query_str) with pytest.raises(TransportProtocolError): async with Client(transport=sample_transport) as session: - await session.execute(query) + async for _result in session.subscribe(query): + await asyncio.sleep(10 * MS) + break -server_error_subscription_server_answer = ( - '{"event":"phx_error", "ref":2, "topic":"test_topic"}' -) +server_error_server_answer = '{"event":"phx_error", "ref":2, "topic":"test_topic"}' @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [server(None, server_error_subscription_server_answer)], indirect=True, + "server", [query_server(server_error_server_answer)], indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_phoenix_channel_server_error(event_loop, server, query_str): @@ -212,3 +453,65 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): with pytest.raises(TransportServerError): async with Client(transport=sample_transport) as session: await session.execute(query) + + +# These cannot be caught by the client +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + subscription_server(unsubscribe_answers=invalid_subscription_ref_answer), + subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + # Reduce close_timeout. These tests will wait for an unsubscribe + # reply that will never come... + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + break + + +# We can force the error if somehow the generator is still running while +# we receive a mismatched unsubscribe answer +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer)], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + await session.transport._send_stop_message(2) + await asyncio.sleep(10 * MS) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index c3679ac6..b13a8c55 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -14,6 +14,68 @@ } """ +default_query_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +@pytest.fixture +def ws_server_helper(request): + from .conftest import PhoenixChannelServerHelper + + yield PhoenixChannelServerHelper + + +async def query_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(default_query_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query(event_loop, server, query_str): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + result = await session.execute(query) + + print("Client received:", result) + + +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + subscription_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -24,7 +86,7 @@ '"topic":"test_topic"}' ) -query1_server_answer = ( +subscription_data_server_answer = ( '{"event":"subscription:data","payload":' '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' @@ -32,33 +94,39 @@ '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +unsubscribe_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -@pytest.fixture -def ws_server_helper(request): - from .conftest import PhoenixChannelServerHelper - - yield PhoenixChannelServerHelper - - -async def phoenix_server(ws, path): +async def subscription_server(ws, path): from .conftest import PhoenixChannelServerHelper await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(subscription_server_answer) - await ws.send(query1_server_answer) - await PhoenixChannelServerHelper.send_close(ws) + await ws.send(subscription_data_server_answer) + await ws.recv() + await ws.send(unsubscribe_server_answer) + # Unsubscribe will remove the listener + # await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() @pytest.mark.asyncio -@pytest.mark.parametrize("server", [phoenix_server], indirect=True) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_simple_query(event_loop, server, query_str): +@pytest.mark.parametrize("server", [subscription_server], indirect=True) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription(event_loop, server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -69,8 +137,11 @@ async def test_phoenix_channel_simple_query(event_loop, server, query_str): channel_name="test_channel", url=url ) + first_result = None query = gql(query_str) async with Client(transport=sample_transport) as session: - result = await session.execute(query) + async for result in session.subscribe(query): + first_result = result + break - print("Client received:", result) + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index ef46db47..3c6ec2b2 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,5 +1,6 @@ import asyncio import json +import sys import pytest from parse import search @@ -9,26 +10,77 @@ # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets -subscription_server_answer = ( - '{"event":"phx_reply",' - '"payload":' - '{"response":' - '{"subscriptionId":"test_subscription"},' - '"status":"ok"},' - '"ref":2,' - '"topic":"test_topic"}' +test_channel = "test_channel" +test_subscription_id = "test_subscription" + +# A server should send this after receiving a 'phx_leave' request message. +# 'query_id' should be the value of the 'ref' in the 'phx_leave' request. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_leave_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{}},' + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" ) -countdown_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"number":{number}}}}}}},' - '"ref":{query_id}}}' +# A server should send this after sending the 'channel_leave_reply' +# above, to confirm to the client that the channel was actually closed. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_close_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_close",' + '"payload":{{}},' + '"ref":null' + "}}" +) + +# A server sends this when it receives a 'subscribe' request, +# after creating a unique subscription id. 'query_id' should be the +# value of the 'ref' in the 'subscribe' request. +subscription_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{' + '"subscriptionId":"{subscription_id}"' + "}}," + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" +) + +countdown_data_template = ( + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"countdown":{{' + '"number":{number}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) async def server_countdown(ws, path): import websockets + from .conftest import MS, PhoenixChannelServerHelper try: @@ -37,20 +89,29 @@ async def server_countdown(ws, path): result = await ws.recv() json_result = json.loads(result) assert json_result["event"] == "doc" - payload = json_result["payload"] - query = payload["query"] + channel_name = json_result["topic"] query_id = json_result["ref"] + payload = json_result["payload"] + query = payload["query"] count_found = search("count: {:d}", query) count = count_found[0] print(f"Countdown started from: {count}") - await ws.send(subscription_server_answer) + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) async def counting_coro(): for number in range(count, -1, -1): await ws.send( - countdown_server_answer.format(query_id=query_id, number=number) + countdown_data_template.format( + subscription_id=test_subscription_id, number=number + ) ) await asyncio.sleep(2 * MS) @@ -59,12 +120,23 @@ async def counting_coro(): async def stopping_coro(): nonlocal counting_task while True: - result = await ws.recv() json_result = json.loads(result) - if json_result["type"] == "stop" and json_result["id"] == str(query_id): - print("Cancelling counting task now") + if json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + subscription_id = payload["subscriptionId"] + assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) counting_task.cancel() stopping_task = asyncio.ensure_future(stopping_coro()) @@ -74,16 +146,17 @@ async def stopping_coro(): except asyncio.CancelledError: print("Now counting task is cancelled") - stopping_task.cancel() - + # Waiting for a clean stop try: - await stopping_task + await asyncio.wait_for(stopping_task, 3) except asyncio.CancelledError: print("Now stopping task is cancelled") + except asyncio.TimeoutError: + print("Now stopping task is in timeout") - await PhoenixChannelServerHelper.send_close(ws) + # await PhoenixChannelServerHelper.send_close(ws) except websockets.exceptions.ConnectionClosedOK: - pass + print("Connection closed") finally: await ws.wait_closed() @@ -100,15 +173,29 @@ async def stopping_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_phoenix_channel_subscription(event_loop, server, subscription_str): +@pytest.mark.parametrize("end_count", [0, 5]) +async def test_phoenix_channel_subscription( + event_loop, server, subscription_str, end_count +): + """Parameterized test. + + :param end_count: Target count at which the test will 'break' to unsubscribe. + """ + import logging + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) + from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url + channel_name=test_channel, url=url, close_timeout=5 ) count = 10 @@ -116,39 +203,156 @@ async def test_phoenix_channel_subscription(event_loop, server, subscription_str async with Client(transport=sample_transport) as session: async for result in session.subscribe(subscription): - - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count + if number == end_count: + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. + if sys.version_info < (3, 7): + await session._generator.aclose() + print("break") + break + count -= 1 - assert count == -1 + assert count == end_count + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_phoenix_channel_subscription_no_break( + event_loop, server, subscription_str +): + import logging + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + async def testing_stopping_without_break(): + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name=test_channel, url=url, close_timeout=5 + ) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with Client(transport=sample_transport) as session: + async for result in session.subscribe(subscription): + number = result["countdown"]["number"] + print(f"Number received: {number}") + + # Simulate a slow consumer + await asyncio.sleep(0.1) -heartbeat_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"heartbeat_count":{count}}}}}}},' - '"ref":1}}' + if number == 9: + # When we consume the number 9 here in the async generator, + # all the 10 numbers have already been sent by the backend and + # are present in the listener queue + # we simulate here an unsubscribe message + # In that case, all the 10 numbers should be consumed in the + # generator and then the generator should be closed properly + await session.transport._send_stop_message(2) + + assert number == count + + count -= 1 + + assert count == -1 + + try: + await asyncio.wait_for(testing_stopping_without_break(), timeout=5) + except asyncio.TimeoutError: + assert False, "The async generator did not stop" + + +heartbeat_data_template = ( + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"heartbeat":{{' + '"heartbeat_count":{count}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) async def phoenix_heartbeat_server(ws, path): + import websockets + from .conftest import PhoenixChannelServerHelper - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) + try: + await PhoenixChannelServerHelper.send_connection_ack(ws) - for i in range(3): - heartbeat_result = await ws.recv() - json_result = json.loads(heartbeat_result) - assert json_result["event"] == "heartbeat" - await ws.send(heartbeat_server_answer.format(count=i)) + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "doc" + channel_name = json_result["topic"] + query_id = json_result["ref"] - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) + + async def heartbeat_coro(): + i = 0 + while True: + heartbeat_result = await ws.recv() + json_result = json.loads(heartbeat_result) + if json_result["event"] == "heartbeat": + await ws.send( + heartbeat_data_template.format( + subscription_id=test_subscription_id, count=i + ) + ) + i = i + 1 + elif json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + subscription_id = payload["subscriptionId"] + assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) + + await asyncio.wait_for(heartbeat_coro(), 60) + # await PhoenixChannelServerHelper.send_close(ws) + except websockets.exceptions.ConnectionClosedOK: + print("Connection closed") + finally: + await ws.wait_closed() heartbeat_subscription_str = """ @@ -171,15 +375,23 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url, heartbeat_interval=1 + channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 async for result in session.subscribe(subscription): - heartbeat_count = result["heartbeat_count"] + heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") assert heartbeat_count == i + if heartbeat_count == 5: + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. + if sys.version_info < (3, 7): + await session._generator.aclose() + break + i += 1