From c7d02cd63e6289a2860398cd553286b36f63e80c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 11 Aug 2021 01:46:22 +0200 Subject: [PATCH 01/24] Initialize heartbeat_task correctly --- gql/transport/phoenix_channel_websockets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 27e58f2a..26ab159b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -34,6 +34,7 @@ def __init__( """ self.channel_name = channel_name self.heartbeat_interval = heartbeat_interval + self.heartbeat_task: Optional[asyncio.Future] = None self.subscription_ids_to_query_ids: Dict[str, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) From fdb8307245309229e71fd2d4c9527caebad3d769 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 11 Aug 2021 01:47:52 +0200 Subject: [PATCH 02/24] Add debug logs to phoenix subscription test --- tests/test_phoenix_channel_subscription.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index ef46db47..16ce5c34 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -101,9 +101,13 @@ async def stopping_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_phoenix_channel_subscription(event_loop, server, subscription_str): + import logging from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" From f9e74e95bed343abc11528e04f409da66a71e259 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 11 Aug 2021 01:48:51 +0200 Subject: [PATCH 03/24] Fix assertion error when transport is already closed when _receive is called --- gql/transport/websockets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 7e26f31c..2dce0545 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -175,8 +175,9 @@ async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - # We should always have an active websocket connection here - assert self.websocket is not None + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") # Wait for the next websocket frame. Can raise ConnectionClosed data: Data = await self.websocket.recv() @@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None: except (ConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break + except TransportClosed: + break # Parse the answer try: From 1e847d5bc7b1812726e15735e20bd2a311c5b555 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 07:41:07 -0700 Subject: [PATCH 04/24] Stop websocket subscriptions on TransportClosed --- gql/transport/websockets.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 2dce0545..d0ef46f2 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -485,16 +485,14 @@ async def subscribe( ) break - except (asyncio.CancelledError, GeneratorExit) as e: + except (asyncio.CancelledError, GeneratorExit, TransportClosed) as e: log.debug("Exception in subscribe: " + repr(e)) if listener.send_stop: await self._send_stop_message(query_id) listener.send_stop = False finally: - del self.listeners[query_id] - if len(self.listeners) == 0: - self._no_more_listeners.set() + self._remove_listener(query_id) async def execute( self, @@ -612,6 +610,19 @@ async def connect(self) -> None: log.debug("connect: done") + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: From caa671de0cd09212cf1166f32d00625bdc968bc5 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 07:47:22 -0700 Subject: [PATCH 05/24] Add logger to PhoenixChannelWebsocketsTransport --- gql/transport/phoenix_channel_websockets.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 26ab159b..a4966d71 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,5 +1,6 @@ import asyncio import json +import logging from typing import Any, Dict, Optional, Tuple from graphql import DocumentNode, ExecutionResult, print_ast @@ -12,6 +13,8 @@ ) from .websockets import WebsocketsTransport +log = logging.getLogger(__name__) + class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport @@ -32,8 +35,8 @@ def __init__( :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client """ - self.channel_name = channel_name - self.heartbeat_interval = heartbeat_interval + self.channel_name: str = channel_name + self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscription_ids_to_query_ids: Dict[str, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) @@ -179,7 +182,7 @@ def _parse_answer( answer_id = self.subscription_ids_to_query_ids[subscription_id] except KeyError: raise ValueError( - f"subscription '{subscription_id}' has not been registerd" + f"subscription '{subscription_id}' has not been registered" ) result = payload.get("result") @@ -238,6 +241,7 @@ def _parse_answer( raise ValueError except ValueError as e: + log.error(f"Error parsing answer '{answer}' " + repr(e)) raise TransportProtocolError( "Server did not return a GraphQL result" ) from e From 3b3cec7eb81ae0fd149b0cca4b0645233057a3fe Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 08:01:28 -0700 Subject: [PATCH 06/24] Implement Absinthe unsubscribe protocol in PhoenixChannelWebsocketsTransport --- gql/transport/phoenix_channel_websockets.py | 65 +++++++++++++++++---- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index a4966d71..0caab07f 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -27,11 +27,12 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """ def __init__( - self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs + self, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, *args, **kwargs ) -> None: """Initialize the transport with the given parameters. - :param channel_name: Channel on the server this transport will join + :param channel_name: Channel on the server this transport will join. + The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client """ @@ -39,6 +40,7 @@ def __init__( self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscription_ids_to_query_ids: Dict[str, int] = {} + self.unsubscribe_answer_ids: Dict[int, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) async def _send_init_message_and_wait_ack(self) -> None: @@ -93,11 +95,37 @@ async def heartbeat_coro(): self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) - async def _send_stop_message(self, query_id: int) -> None: - try: - await self.listeners[query_id].put(("complete", None)) - except KeyError: # pragma: no cover - pass + async def _send_stop_message(self, listener_query_id: int) -> None: + """Send an 'unsubscribe' message to the Phoenix Channel referencing + the listener's query_id, saving the query_id of the message. + + The server should afterwards return a 'phx_reply' message with + the same query_id and subscription_id of the 'unsubscribe' request. + """ + query_id = self.next_query_id + self.next_query_id += 1 + + subscription_id = None + for sub_id, q_id in self.subscription_ids_to_query_ids.items(): + if q_id == listener_query_id: + subscription_id = sub_id + break + + if subscription_id is None: + raise ValueError(f"No subscription for {listener_query_id}") + + # Save the ref so it can be matched in the reply + self.unsubscribe_answer_ids[query_id] = listener_query_id + unsubscribe_message = json.dumps( + { + "topic": self.channel_name, + "event": "unsubscribe", + "payload": {"subscriptionId": subscription_id}, + "ref": query_id, + } + ) + + await self._send(unsubscribe_message) async def _send_connection_terminate_message(self) -> None: """Send a phx_leave message to disconnect from the provided channel. @@ -156,7 +184,7 @@ def _parse_answer( Returns a list consisting of: - the answer_type (between: - 'heartbeat', 'data', 'reply', 'error', 'close') + 'heartbeat', 'data', 'reply', 'error', 'unsubscribe') - the answer id (Integer) if received or None - an execution Result if the answer_type is 'data' or None """ @@ -207,6 +235,9 @@ def _parse_answer( status = str(payload.get("status")) + # Unsubscription reply? + unsubscribe_listener_id = self.unsubscribe_answer_ids.pop(answer_id, None) + if status == "ok": answer_type = "reply" @@ -214,7 +245,16 @@ def _parse_answer( if isinstance(response, dict) and "subscriptionId" in response: subscription_id = str(response.get("subscriptionId")) - self.subscription_ids_to_query_ids[subscription_id] = answer_id + if unsubscribe_listener_id is not None: + + answer_id = unsubscribe_listener_id + answer_type = "unsubscribe" + + if self.subscription_ids_to_query_ids.get(subscription_id) != unsubscribe_listener_id: + raise ValueError(f"Listener {unsubscribe_listener_id} referenced in unsubscribe reply does not exist") + else: + # Subscription reply + self.subscription_ids_to_query_ids[subscription_id] = answer_id elif status == "error": response = payload.get("response") @@ -254,7 +294,12 @@ async def _handle_answer( answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: - if answer_type == "close": + if answer_type == "unsubscribe": + # Remove the listener here, to possibly signal + # that it is the last listener in the session. + assert answer_id is not None + self._remove_listener(answer_id) + elif answer_type == "close": await self.close() else: await super()._handle_answer(answer_type, answer_id, execution_result) From 4ec38ff05f60eb988d81e135b6884e4fb4a01365 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 09:19:06 -0700 Subject: [PATCH 07/24] Add unsubscribe to Phoenix Channel subscription tests --- tests/test_phoenix_channel_subscription.py | 208 ++++++++++++++++----- 1 file changed, 158 insertions(+), 50 deletions(-) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 16ce5c34..8ff191c5 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -9,23 +9,72 @@ # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets -subscription_server_answer = ( - '{"event":"phx_reply",' - '"payload":' - '{"response":' - '{"subscriptionId":"test_subscription"},' - '"status":"ok"},' - '"ref":2,' - '"topic":"test_topic"}' +test_channel = "test_channel" +test_subscription_id = "test_subscription" + +# A server should send this after receiving a 'phx_leave' request message. +# 'query_id' should be the value of the 'ref' in the 'phx_leave' request. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_leave_reply_template = ( + '{{' + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{}},' + '"status":"ok"' + '}},' + '"ref":{query_id}' + '}}' ) -countdown_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"number":{number}}}}}}},' - '"ref":{query_id}}}' +# A server should send this after sending the 'channel_leave_reply' +# above, to confirm to the client that the channel was actually closed. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_close_reply_template = ( + '{{' + '"topic":"{channel_name}",' + '"event":"phx_close",' + '"payload":{},' + '"ref":null' + '}}' ) +# A server sends this when it receives a 'subscribe' request, +# after creating a unique subscription id. 'query_id' should be the +# value of the 'ref' in the 'subscribe' request. +subscription_reply_template = ( + '{{' + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{' + '"subscriptionId":"{subscription_id}"' + '}},' + '"status":"ok"' + '}},' + '"ref":{query_id}' + '}}' +) + +countdown_data_template = ( + '{{' + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"countdown":{{' + '"number":{number}' + '}}' + '}}' + '}}' + '}},' + '"ref":null' + '}}' +) async def server_countdown(ws, path): import websockets @@ -37,20 +86,26 @@ async def server_countdown(ws, path): result = await ws.recv() json_result = json.loads(result) assert json_result["event"] == "doc" - payload = json_result["payload"] - query = payload["query"] + channel_name = json_result["topic"] query_id = json_result["ref"] + payload = json_result["payload"] + query = payload["query"] count_found = search("count: {:d}", query) count = count_found[0] print(f"Countdown started from: {count}") - await ws.send(subscription_server_answer) + await ws.send(subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id)) async def counting_coro(): for number in range(count, -1, -1): await ws.send( - countdown_server_answer.format(query_id=query_id, number=number) + countdown_data_template.format( + subscription_id=test_subscription_id, + number=number) ) await asyncio.sleep(2 * MS) @@ -59,13 +114,19 @@ async def counting_coro(): async def stopping_coro(): nonlocal counting_task while True: - result = await ws.recv() json_result = json.loads(result) - if json_result["type"] == "stop" and json_result["id"] == str(query_id): - print("Cancelling counting task now") - counting_task.cancel() + if json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + if payload["subscriptionId"] == test_subscription_id: + print("Sending unsubscribe reply") + counting_task.cancel() + await ws.send(subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id)) stopping_task = asyncio.ensure_future(stopping_coro()) @@ -81,13 +142,12 @@ async def stopping_coro(): except asyncio.CancelledError: print("Now stopping task is cancelled") - await PhoenixChannelServerHelper.send_close(ws) + # await PhoenixChannelServerHelper.send_close(ws) except websockets.exceptions.ConnectionClosedOK: - pass + print("Connection closed") finally: await ws.wait_closed() - countdown_subscription_str = """ subscription {{ countdown (count: {count}) {{ @@ -100,19 +160,25 @@ async def stopping_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_phoenix_channel_subscription(event_loop, server, subscription_str): +@pytest.mark.parametrize("end_count", [0, 5]) +async def test_phoenix_channel_subscription(event_loop, server, subscription_str, end_count): + """Parameterized test. + + :param end_count: Target count at which the test will 'break' to unsubscribe. + """ import logging from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) from gql.transport.websockets import log as websockets_logger - + from gql.transport.phoenix_channel_websockets import log as phoenix_logger websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url + channel_name=test_channel, url=url ) count = 10 @@ -120,40 +186,79 @@ async def test_phoenix_channel_subscription(event_loop, server, subscription_str async with Client(transport=sample_transport) as session: async for result in session.subscribe(subscription): - - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count + if number == end_count: + # Breaking will unsubscribe + break count -= 1 - assert count == -1 - - -heartbeat_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"heartbeat_count":{count}}}}}}},' - '"ref":1}}' + assert count == end_count + + +heartbeat_data_template = ( + '{{' + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"heartbeat":{{' + '"heartbeat_count":{count}' + '}}' + '}}' + '}}' + '}},' + '"ref":null' + '}}' ) - async def phoenix_heartbeat_server(ws, path): from .conftest import PhoenixChannelServerHelper - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) + try: + await PhoenixChannelServerHelper.send_connection_ack(ws) - for i in range(3): - heartbeat_result = await ws.recv() - json_result = json.loads(heartbeat_result) - assert json_result["event"] == "heartbeat" - await ws.send(heartbeat_server_answer.format(count=i)) + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "doc" + channel_name = json_result["topic"] + query_id = json_result["ref"] - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + await ws.send(subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id)) + async def heartbeat_coro(): + i = 0 + while True: + heartbeat_result = await ws.recv() + json_result = json.loads(heartbeat_result) + if json_result["event"] == "heartbeat": + await ws.send(heartbeat_data_template.format( + subscription_id=test_subscription_id, + count=i)) + i = i + 1 + elif json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + if payload["subscriptionId"] == test_subscription_id: + print("Sending unsubscribe reply") + await ws.send(subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id)) + + await asyncio.wait_for(heartbeat_coro(), 60) + # await PhoenixChannelServerHelper.send_close(ws) + except websockets.exceptions.ConnectionClosedOK: + print("Connection closed") + finally: + await ws.wait_closed() heartbeat_subscription_str = """ subscription { @@ -175,15 +280,18 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url, heartbeat_interval=1 + channel_name=test_channel, url=url, heartbeat_interval=1 ) subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 async for result in session.subscribe(subscription): - heartbeat_count = result["heartbeat_count"] + heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") assert heartbeat_count == i + if heartbeat_count == 5: + # Breaking will unsubscribe + break i += 1 From 018ab64270e4ad9c584e8467c981835f4edcc514 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 10:12:59 -0700 Subject: [PATCH 08/24] Fix typo in unused template --- tests/test_phoenix_channel_subscription.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 8ff191c5..1e2c2109 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -36,7 +36,7 @@ '{{' '"topic":"{channel_name}",' '"event":"phx_close",' - '"payload":{},' + '"payload":{{}},' '"ref":null' '}}' ) From 34c3eb475d4081b11bcef0529b526061917e011a Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Fri, 13 Aug 2021 14:56:45 -0700 Subject: [PATCH 09/24] Add generator aclose in Phoenix Channel tests for Python<3.7 --- gql/transport/phoenix_channel_websockets.py | 40 ++-- gql/transport/websockets.py | 2 +- tests/test_phoenix_channel_subscription.py | 218 ++++++++++++-------- 3 files changed, 157 insertions(+), 103 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 0caab07f..0313f559 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -27,7 +27,11 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """ def __init__( - self, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, *args, **kwargs + self, + channel_name: str = "__absinthe__:control", + heartbeat_interval: float = 30, + *args, + **kwargs, ) -> None: """Initialize the transport with the given parameters. @@ -40,7 +44,7 @@ def __init__( self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscription_ids_to_query_ids: Dict[str, int] = {} - self.unsubscribe_answer_ids: Dict[int, int] = {} + self.unsub_to_listener_query_ids: Dict[int, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) async def _send_init_message_and_wait_ack(self) -> None: @@ -112,10 +116,12 @@ async def _send_stop_message(self, listener_query_id: int) -> None: break if subscription_id is None: - raise ValueError(f"No subscription for {listener_query_id}") + raise ValueError( + f"No subscription for {listener_query_id}" + ) # pragma: no cover # Save the ref so it can be matched in the reply - self.unsubscribe_answer_ids[query_id] = listener_query_id + self.unsub_to_listener_query_ids[query_id] = listener_query_id unsubscribe_message = json.dumps( { "topic": self.channel_name, @@ -128,8 +134,7 @@ async def _send_stop_message(self, listener_query_id: int) -> None: await self._send(unsubscribe_message) async def _send_connection_terminate_message(self) -> None: - """Send a phx_leave message to disconnect from the provided channel. - """ + """Send a phx_leave message to disconnect from the provided channel.""" query_id = self.next_query_id self.next_query_id += 1 @@ -236,7 +241,9 @@ def _parse_answer( status = str(payload.get("status")) # Unsubscription reply? - unsubscribe_listener_id = self.unsubscribe_answer_ids.pop(answer_id, None) + listener_query_id = self.unsub_to_listener_query_ids.pop( + answer_id, None + ) if status == "ok": @@ -245,16 +252,24 @@ def _parse_answer( if isinstance(response, dict) and "subscriptionId" in response: subscription_id = str(response.get("subscriptionId")) - if unsubscribe_listener_id is not None: + if listener_query_id is not None: - answer_id = unsubscribe_listener_id + answer_id = listener_query_id answer_type = "unsubscribe" - if self.subscription_ids_to_query_ids.get(subscription_id) != unsubscribe_listener_id: - raise ValueError(f"Listener {unsubscribe_listener_id} referenced in unsubscribe reply does not exist") + if ( + self.subscription_ids_to_query_ids.get(subscription_id) + != listener_query_id + ): + raise ValueError( + f"Listener {listener_query_id} " + "in unsubscribe reply does not exist" + ) else: # Subscription reply - self.subscription_ids_to_query_ids[subscription_id] = answer_id + self.subscription_ids_to_query_ids[ + subscription_id + ] = answer_id elif status == "error": response = payload.get("response") @@ -297,7 +312,6 @@ async def _handle_answer( if answer_type == "unsubscribe": # Remove the listener here, to possibly signal # that it is the last listener in the session. - assert answer_id is not None self._remove_listener(answer_id) elif answer_type == "close": await self.close() diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index d0ef46f2..3c343652 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -611,7 +611,7 @@ async def connect(self) -> None: log.debug("connect: done") def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and + """After exiting from a subscription, remove the listener and signal an event if this was the last listener for the client. """ if query_id in self.listeners: diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 1e2c2109..537e4979 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,5 +1,6 @@ import asyncio import json +import sys import pytest from parse import search @@ -13,71 +14,73 @@ test_subscription_id = "test_subscription" # A server should send this after receiving a 'phx_leave' request message. -# 'query_id' should be the value of the 'ref' in the 'phx_leave' request. +# 'query_id' should be the value of the 'ref' in the 'phx_leave' request. # With only one listener, the transport is closed automatically when # it exits a subscription, so this is not used in current tests. channel_leave_reply_template = ( - '{{' - '"topic":"{channel_name}",' - '"event":"phx_reply",' - '"payload":{{' - '"response":{{}},' - '"status":"ok"' - '}},' - '"ref":{query_id}' - '}}' + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{}},' + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" ) # A server should send this after sending the 'channel_leave_reply' -# above, to confirm to the client that the channel was actually closed. +# above, to confirm to the client that the channel was actually closed. # With only one listener, the transport is closed automatically when # it exits a subscription, so this is not used in current tests. channel_close_reply_template = ( - '{{' - '"topic":"{channel_name}",' - '"event":"phx_close",' - '"payload":{{}},' - '"ref":null' - '}}' + "{{" + '"topic":"{channel_name}",' + '"event":"phx_close",' + '"payload":{{}},' + '"ref":null' + "}}" ) # A server sends this when it receives a 'subscribe' request, -# after creating a unique subscription id. 'query_id' should be the -# value of the 'ref' in the 'subscribe' request. +# after creating a unique subscription id. 'query_id' should be the +# value of the 'ref' in the 'subscribe' request. subscription_reply_template = ( - '{{' - '"topic":"{channel_name}",' - '"event":"phx_reply",' - '"payload":{{' - '"response":{{' - '"subscriptionId":"{subscription_id}"' - '}},' - '"status":"ok"' - '}},' - '"ref":{query_id}' - '}}' + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{' + '"subscriptionId":"{subscription_id}"' + "}}," + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" ) countdown_data_template = ( - '{{' - '"topic":"{subscription_id}",' - '"event":"subscription:data",' - '"payload":{{' - '"subscriptionId":"{subscription_id}",' - '"result":{{' - '"data":{{' - '"countdown":{{' - '"number":{number}' - '}}' - '}}' - '}}' - '}},' - '"ref":null' - '}}' + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"countdown":{{' + '"number":{number}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) + async def server_countdown(ws, path): import websockets + from .conftest import MS, PhoenixChannelServerHelper try: @@ -95,17 +98,20 @@ async def server_countdown(ws, path): count = count_found[0] print(f"Countdown started from: {count}") - await ws.send(subscription_reply_template.format( - subscription_id=test_subscription_id, - channel_name=channel_name, - query_id=query_id)) + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) async def counting_coro(): for number in range(count, -1, -1): await ws.send( countdown_data_template.format( - subscription_id=test_subscription_id, - number=number) + subscription_id=test_subscription_id, number=number + ) ) await asyncio.sleep(2 * MS) @@ -120,13 +126,18 @@ async def stopping_coro(): if json_result["event"] == "unsubscribe": query_id = json_result["ref"] payload = json_result["payload"] - if payload["subscriptionId"] == test_subscription_id: - print("Sending unsubscribe reply") - counting_task.cancel() - await ws.send(subscription_reply_template.format( - subscription_id=test_subscription_id, + subscription_id = payload["subscriptionId"] + assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + counting_task.cancel() + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, channel_name=channel_name, - query_id=query_id)) + query_id=query_id, + ) + ) stopping_task = asyncio.ensure_future(stopping_coro()) @@ -148,6 +159,7 @@ async def stopping_coro(): finally: await ws.wait_closed() + countdown_subscription_str = """ subscription {{ countdown (count: {count}) {{ @@ -161,17 +173,21 @@ async def stopping_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("end_count", [0, 5]) -async def test_phoenix_channel_subscription(event_loop, server, subscription_str, end_count): - """Parameterized test. +async def test_phoenix_channel_subscription( + event_loop, server, subscription_str, end_count +): + """Parameterized test. :param end_count: Target count at which the test will 'break' to unsubscribe. """ import logging + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) - from gql.transport.websockets import log as websockets_logger from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -191,32 +207,40 @@ async def test_phoenix_channel_subscription(event_loop, server, subscription_str assert number == count if number == end_count: - # Breaking will unsubscribe + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. + if sys.version_info < (3, 7): + await session._generator.aclose() break + count -= 1 assert count == end_count heartbeat_data_template = ( - '{{' - '"topic":"{subscription_id}",' - '"event":"subscription:data",' - '"payload":{{' - '"subscriptionId":"{subscription_id}",' - '"result":{{' - '"data":{{' - '"heartbeat":{{' - '"heartbeat_count":{count}' - '}}' - '}}' - '}}' - '}},' - '"ref":null' - '}}' + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"heartbeat":{{' + '"heartbeat_count":{count}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) + async def phoenix_heartbeat_server(ws, path): + import websockets + from .conftest import PhoenixChannelServerHelper try: @@ -228,10 +252,13 @@ async def phoenix_heartbeat_server(ws, path): channel_name = json_result["topic"] query_id = json_result["ref"] - await ws.send(subscription_reply_template.format( - subscription_id=test_subscription_id, - channel_name=channel_name, - query_id=query_id)) + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) async def heartbeat_coro(): i = 0 @@ -239,19 +266,26 @@ async def heartbeat_coro(): heartbeat_result = await ws.recv() json_result = json.loads(heartbeat_result) if json_result["event"] == "heartbeat": - await ws.send(heartbeat_data_template.format( - subscription_id=test_subscription_id, - count=i)) + await ws.send( + heartbeat_data_template.format( + subscription_id=test_subscription_id, count=i + ) + ) i = i + 1 elif json_result["event"] == "unsubscribe": query_id = json_result["ref"] payload = json_result["payload"] - if payload["subscriptionId"] == test_subscription_id: - print("Sending unsubscribe reply") - await ws.send(subscription_reply_template.format( - subscription_id=test_subscription_id, + subscription_id = payload["subscriptionId"] + assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, channel_name=channel_name, - query_id=query_id)) + query_id=query_id, + ) + ) await asyncio.wait_for(heartbeat_coro(), 60) # await PhoenixChannelServerHelper.send_close(ws) @@ -260,6 +294,7 @@ async def heartbeat_coro(): finally: await ws.wait_closed() + heartbeat_subscription_str = """ subscription { heartbeat { @@ -292,6 +327,11 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): assert heartbeat_count == i if heartbeat_count == 5: - # Breaking will unsubscribe + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. + if sys.version_info < (3, 7): + await session._generator.aclose() break + i += 1 From 01f4a45998d4fc01d7663da404813dac0f8d70e3 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Sun, 15 Aug 2021 22:28:32 -0700 Subject: [PATCH 10/24] Fix phoenix parse_answer and subscriptions --- gql/transport/phoenix_channel_websockets.py | 147 +++++++++++++------- tests/conftest.py | 7 +- 2 files changed, 103 insertions(+), 51 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 0313f559..c990d045 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -16,6 +16,14 @@ log = logging.getLogger(__name__) +class Subscription: + """Records listener_id and unsubscribe query_id for a subscription.""" + + def __init__(self, query_id: int) -> None: + self.listener_id: int = query_id + self.unsubscribe_id: Optional[int] = None + + class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport which allows you to execute queries and subscriptions against an `Absinthe`_ @@ -43,8 +51,7 @@ def __init__( self.channel_name: str = channel_name self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None - self.subscription_ids_to_query_ids: Dict[str, int] = {} - self.unsub_to_listener_query_ids: Dict[int, int] = {} + self.subscriptions: Dict[str, Subscription] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) async def _send_init_message_and_wait_ack(self) -> None: @@ -99,35 +106,26 @@ async def heartbeat_coro(): self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) - async def _send_stop_message(self, listener_query_id: int) -> None: + async def _send_stop_message(self, query_id: int) -> None: """Send an 'unsubscribe' message to the Phoenix Channel referencing the listener's query_id, saving the query_id of the message. The server should afterwards return a 'phx_reply' message with the same query_id and subscription_id of the 'unsubscribe' request. """ - query_id = self.next_query_id - self.next_query_id += 1 - - subscription_id = None - for sub_id, q_id in self.subscription_ids_to_query_ids.items(): - if q_id == listener_query_id: - subscription_id = sub_id - break + subscription_id = self._find_subscription_id(query_id) - if subscription_id is None: - raise ValueError( - f"No subscription for {listener_query_id}" - ) # pragma: no cover + unsubscribe_query_id = self.next_query_id + self.next_query_id += 1 # Save the ref so it can be matched in the reply - self.unsub_to_listener_query_ids[query_id] = listener_query_id + self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id unsubscribe_message = json.dumps( { "topic": self.channel_name, "event": "unsubscribe", "payload": {"subscriptionId": subscription_id}, - "ref": query_id, + "ref": unsubscribe_query_id, } ) @@ -212,7 +210,7 @@ def _parse_answer( subscription_id = str(payload.get("subscriptionId")) try: - answer_id = self.subscription_ids_to_query_ids[subscription_id] + answer_id = self.subscriptions[subscription_id].listener_id except KeyError: raise ValueError( f"subscription '{subscription_id}' has not been registered" @@ -226,13 +224,20 @@ def _parse_answer( answer_type = "data" execution_result = ExecutionResult( - errors=payload.get("errors"), data=result.get("data"), - extensions=payload.get("extensions"), + errors=result.get("errors"), + extensions=result.get("extensions"), ) elif event == "phx_reply": - answer_id = int(json_answer.get("ref")) + try: + answer_id = int(json_answer.get("ref")) + except Exception: + raise ValueError("ref is not an integer") + + if answer_id == 0: + raise ValueError("ref is zero") + payload = json_answer.get("payload") if not isinstance(payload, dict): @@ -240,36 +245,30 @@ def _parse_answer( status = str(payload.get("status")) - # Unsubscription reply? - listener_query_id = self.unsub_to_listener_query_ids.pop( - answer_id, None - ) - if status == "ok": - + # join or leave answer answer_type = "reply" + response = payload.get("response") - if isinstance(response, dict) and "subscriptionId" in response: - subscription_id = str(response.get("subscriptionId")) - if listener_query_id is not None: - - answer_id = listener_query_id - answer_type = "unsubscribe" - - if ( - self.subscription_ids_to_query_ids.get(subscription_id) - != listener_query_id - ): - raise ValueError( - f"Listener {listener_query_id} " - "in unsubscribe reply does not exist" - ) - else: - # Subscription reply - self.subscription_ids_to_query_ids[ - subscription_id - ] = answer_id + if isinstance(response, dict): + if "data" in response: + # query or mutation result answer + answer_type = "data" + + execution_result = ExecutionResult( + data=response["data"], + errors=response.get("errors"), + extensions=response.get("extensions"), + ) + elif "subscriptionId" in response: + # subscribe or unsubscribe answer + subscription_id = str(response.get("subscriptionId")) + + # answer_type is "reply" or "unsubscribe" + answer_type, answer_id = self._parse_subscription_answer( + answer_id, subscription_id + ) elif status == "error": response = payload.get("response") @@ -289,6 +288,9 @@ def _parse_answer( raise TransportQueryError("reply timeout", query_id=answer_id) elif event == "phx_error": + # Sent if the channel has crashed + # answer_id will be the "join_ref" for the channel + # answer_id = int(json_answer.get("ref")) raise TransportServerError("Server error") elif event == "phx_close": answer_type = "close" @@ -303,21 +305,66 @@ def _parse_answer( return answer_type, answer_id, execution_result + def _parse_subscription_answer( + self, answer_id: int, subscription_id: str + ) -> Tuple[str, int]: + if subscription_id in self.subscriptions: + # Unsubscribe reply + subscription = self.subscriptions[subscription_id] + unsubscribe_id = subscription.unsubscribe_id + if unsubscribe_id is None or unsubscribe_id != answer_id: + raise TransportProtocolError( + "Unsubscription reply does not match an request" + ) + + answer_id = subscription.listener_id + answer_type = "unsubscribe" + else: + # Subscription reply + self.subscriptions[subscription_id] = Subscription(answer_id) + + # Confirm unsubscribe for subscriptions + self.listeners[answer_id].send_stop = True + + answer_type = "reply" + + return answer_type, answer_id + async def _handle_answer( self, answer_type: str, answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: - if answer_type == "unsubscribe": + if answer_type == "close": + await self.close() + elif answer_type == "unsubscribe": # Remove the listener here, to possibly signal # that it is the last listener in the session. self._remove_listener(answer_id) - elif answer_type == "close": - await self.close() else: await super()._handle_answer(answer_type, answer_id, execution_result) + def _remove_listener(self, query_id) -> None: + try: + subscription_id = self._find_subscription_id(query_id) + del self.subscriptions[subscription_id] + except Exception: + pass + super()._remove_listener(query_id) + + def _find_subscription_id(self, query_id: int) -> str: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. + """ + for subscription_id, subscription in self.subscriptions.items(): + if subscription.listener_id == query_id: + return subscription_id + + raise TransportProtocolError( + f"No subscription matches listener {query_id}" + ) # pragma: no cover + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: if self.heartbeat_task is not None: self.heartbeat_task.cancel() diff --git a/tests/conftest.py b/tests/conftest.py index 62f107ac..ed57fdd3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,12 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.legacy.server", "gql.transport.websockets", "gql.dsl"]: +for name in [ + "websockets.legacy.server", + "gql.transport.websockets", + "gql.transport.phoenix_channel_websockets", + "gql.dsl", +]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) From fd0b32f46d7e3e4ef6f005f36c7cf255147a9d0c Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Mon, 16 Aug 2021 07:40:49 -0700 Subject: [PATCH 11/24] parse responses per graphql spec --- gql/transport/phoenix_channel_websockets.py | 98 +++++++++++---------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index c990d045..fbd5e6c6 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -196,6 +196,15 @@ def _parse_answer( answer_id: Optional[int] = None answer_type: str = "" execution_result: Optional[ExecutionResult] = None + subscription_id: Optional[str] = None + + def _get_existing_subscription(d: dict) -> Subscription: + subscription_id = d.get("subscriptionId") + if subscription_id is not None: + subscription_id = str(subscription_id) + if subscription_id is None or subscription_id not in self.subscriptions: + raise ValueError("unregistered subscription id") + return self.subscriptions[subscription_id] try: json_answer = json.loads(answer) @@ -208,21 +217,16 @@ def _parse_answer( if not isinstance(payload, dict): raise ValueError("payload is not a dict") - subscription_id = str(payload.get("subscriptionId")) - try: - answer_id = self.subscriptions[subscription_id].listener_id - except KeyError: - raise ValueError( - f"subscription '{subscription_id}' has not been registered" - ) + subscription = _get_existing_subscription(payload) result = payload.get("result") - if not isinstance(result, dict): raise ValueError("result is not a dict") answer_type = "data" + answer_id = subscription.listener_id + execution_result = ExecutionResult( data=result.get("data"), errors=result.get("errors"), @@ -246,29 +250,52 @@ def _parse_answer( status = str(payload.get("status")) if status == "ok": - # join or leave answer + response = payload.get("response") + answer_type = "reply" - response = payload.get("response") + if answer_id in self.listeners: + # Query, mutation or subscription answer + if not isinstance(response, dict): + raise ValueError("response is not a dict") + + subscription_id = response.get("subscriptionId") + if subscription_id is not None: + # Subscription answer + subscription_id = str(subscription_id) + self.subscriptions[subscription_id] = Subscription( + answer_id + ) + + # Confirm unsubscribe for subscriptions + self.listeners[answer_id].send_stop = True + else: + # Query or mutation answer + # GraphQL spec says only three keys are permitted + keys = set(response.keys()) + invalid = keys - {"data", "errors", "extensions"} + if len(invalid) > 0: + raise ValueError( + "response contains invalid items: " + + ", ".join(invalid) + ) - if isinstance(response, dict): - if "data" in response: - # query or mutation result answer answer_type = "data" execution_result = ExecutionResult( - data=response["data"], + data=response.get("data"), errors=response.get("errors"), extensions=response.get("extensions"), ) - elif "subscriptionId" in response: - # subscribe or unsubscribe answer - subscription_id = str(response.get("subscriptionId")) - # answer_type is "reply" or "unsubscribe" - answer_type, answer_id = self._parse_subscription_answer( - answer_id, subscription_id - ) + elif isinstance(response, dict) and "subscriptionId" in response: + subscription = _get_existing_subscription(response) + if answer_id != subscription.unsubscribe_id: + raise ValueError("invalid unsubscribe answer id") + + answer_type = "unsubscribe" + + answer_id = subscription.listener_id elif status == "error": response = payload.get("response") @@ -295,7 +322,7 @@ def _parse_answer( elif event == "phx_close": answer_type = "close" else: - raise ValueError + raise ValueError("unrecognized event") except ValueError as e: log.error(f"Error parsing answer '{answer}' " + repr(e)) @@ -305,31 +332,6 @@ def _parse_answer( return answer_type, answer_id, execution_result - def _parse_subscription_answer( - self, answer_id: int, subscription_id: str - ) -> Tuple[str, int]: - if subscription_id in self.subscriptions: - # Unsubscribe reply - subscription = self.subscriptions[subscription_id] - unsubscribe_id = subscription.unsubscribe_id - if unsubscribe_id is None or unsubscribe_id != answer_id: - raise TransportProtocolError( - "Unsubscription reply does not match an request" - ) - - answer_id = subscription.listener_id - answer_type = "unsubscribe" - else: - # Subscription reply - self.subscriptions[subscription_id] = Subscription(answer_id) - - # Confirm unsubscribe for subscriptions - self.listeners[answer_id].send_stop = True - - answer_type = "reply" - - return answer_type, answer_id - async def _handle_answer( self, answer_type: str, @@ -362,7 +364,7 @@ def _find_subscription_id(self, query_id: int) -> str: return subscription_id raise TransportProtocolError( - f"No subscription matches listener {query_id}" + f"No subscription registered for listener {query_id}" ) # pragma: no cover async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: From ca242a9ab9d211341b53b3cd7638550525bfb664 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 16 Aug 2021 17:40:51 +0200 Subject: [PATCH 12/24] Refactor unsubscription reply --- gql/transport/phoenix_channel_websockets.py | 43 +++++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 0313f559..8b8f82b3 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -240,11 +240,6 @@ def _parse_answer( status = str(payload.get("status")) - # Unsubscription reply? - listener_query_id = self.unsub_to_listener_query_ids.pop( - answer_id, None - ) - if status == "ok": answer_type = "reply" @@ -252,25 +247,39 @@ def _parse_answer( if isinstance(response, dict) and "subscriptionId" in response: subscription_id = str(response.get("subscriptionId")) - if listener_query_id is not None: - answer_id = listener_query_id - answer_type = "unsubscribe" + listener_query_id = self.unsub_to_listener_query_ids.pop( + answer_id, None + ) - if ( - self.subscription_ids_to_query_ids.get(subscription_id) - != listener_query_id - ): - raise ValueError( - f"Listener {listener_query_id} " - "in unsubscribe reply does not exist" - ) - else: + if listener_query_id is None: # Subscription reply + self.subscription_ids_to_query_ids[ subscription_id ] = answer_id + else: + + expected_id = self.subscription_ids_to_query_ids.get( + subscription_id + ) + + if listener_query_id == expected_id: + # Unsubscription reply + + answer_id = listener_query_id + answer_type = "unsubscribe" + + else: + + raise ValueError( + "Unexpected listener_query_id for " + f"subscriptionId={subscription_id}. " + f"Expected={expected_id} " + f"Received={listener_query_id} " + ) + elif status == "error": response = payload.get("response") From 7c2f85efe29120596fbebed4e8f8239709364524 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Mon, 16 Aug 2021 16:30:04 -0700 Subject: [PATCH 13/24] Note the operation type for each listener --- gql/transport/websockets.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 3c343652..02dce37d 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -6,7 +6,8 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast import websockets -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import DocumentNode, ExecutionResult, OperationType, print_ast +from graphql.utilities import get_operation_ast from websockets.client import WebSocketClientProtocol from websockets.datastructures import HeadersLike from websockets.exceptions import ConnectionClosed @@ -34,8 +35,17 @@ class ListenerQueue: to the consumer once all the previous messages have been consumed from the queue """ - def __init__(self, query_id: int, send_stop: bool) -> None: + def __init__( + self, query_id: int, operation_type: Optional[OperationType], send_stop: bool + ) -> None: + operation: str = "query" + if operation_type is not None: + if operation_type == OperationType.MUTATION: # pragma: no cover + operation = "mutation" + elif operation_type == OperationType.SUBSCRIPTION: + operation = "subscription" self.query_id: int = query_id + self.operation: str = operation self.send_stop: bool = send_stop self._queue: asyncio.Queue = asyncio.Queue() self._closed: bool = False @@ -458,7 +468,14 @@ async def subscribe( ) # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + operation_type = None + operation = get_operation_ast(document, operation_name) + if operation is not None: + operation_type = operation.operation + + listener = ListenerQueue( + query_id, operation_type, send_stop=(send_stop is True) + ) self.listeners[query_id] = listener # We will need to wait at close for this query to clean properly From f35de0367c6512630a34c36d886fa2291467d13f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 16 Aug 2021 17:42:22 +0200 Subject: [PATCH 14/24] Make phoenix behavior more similar to the websockets code Now an unsubscribe reply will generate a 'complete' message to stop the async generator cleanly --- gql/transport/phoenix_channel_websockets.py | 22 +++++++++++++-------- gql/transport/websockets.py | 5 +++-- tests/conftest.py | 9 +++++++-- tests/test_phoenix_channel_subscription.py | 12 ++++++----- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 8b8f82b3..94b7ada5 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -189,7 +189,7 @@ def _parse_answer( Returns a list consisting of: - the answer_type (between: - 'heartbeat', 'data', 'reply', 'error', 'unsubscribe') + 'heartbeat', 'data', 'reply', 'error', 'complete', 'close') - the answer id (Integer) if received or None - an execution Result if the answer_type is 'data' or None """ @@ -259,6 +259,11 @@ def _parse_answer( subscription_id ] = answer_id + log.debug( + f"Saving query_id={answer_id} for " + f"subscription_id={subscription_id}" + ) + else: expected_id = self.subscription_ids_to_query_ids.get( @@ -269,7 +274,12 @@ def _parse_answer( # Unsubscription reply answer_id = listener_query_id - answer_type = "unsubscribe" + answer_type = "complete" + + log.debug( + f"Sending complete with query_id={answer_id} " + f"for subscription_id={subscription_id}" + ) else: @@ -305,7 +315,7 @@ def _parse_answer( raise ValueError except ValueError as e: - log.error(f"Error parsing answer '{answer}' " + repr(e)) + log.error(f"Error parsing answer '{answer}': {e!r}") raise TransportProtocolError( "Server did not return a GraphQL result" ) from e @@ -318,11 +328,7 @@ async def _handle_answer( answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: - if answer_type == "unsubscribe": - # Remove the listener here, to possibly signal - # that it is the last listener in the session. - self._remove_listener(answer_id) - elif answer_type == "close": + if answer_type == "close": await self.close() else: await super()._handle_answer(answer_type, answer_id, execution_result) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 3c343652..2e9db3c4 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -486,12 +486,13 @@ async def subscribe( break except (asyncio.CancelledError, GeneratorExit, TransportClosed) as e: - log.debug("Exception in subscribe: " + repr(e)) + log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: await self._send_stop_message(query_id) listener.send_stop = False finally: + log.debug(f"In subscribe finally for query_id {query_id}") self._remove_listener(query_id) async def execute( @@ -641,7 +642,7 @@ async def _clean_close(self, e: Exception) -> None: try: await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) except asyncio.TimeoutError: # pragma: no cover - pass + log.debug("Timer close_timeout fired") # Finally send the 'connection_terminate' message await self._send_connection_terminate_message() diff --git a/tests/conftest.py b/tests/conftest.py index 62f107ac..df69c121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,12 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.legacy.server", "gql.transport.websockets", "gql.dsl"]: +for name in [ + "websockets.legacy.server", + "gql.transport.websockets", + "gql.transport.phoenix_channel_websockets", + "gql.dsl", +]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) @@ -170,7 +175,7 @@ async def stop(self): self.server.close() try: - await asyncio.wait_for(self.server.wait_closed(), timeout=1) + await asyncio.wait_for(self.server.wait_closed(), timeout=5) except asyncio.TimeoutError: # pragma: no cover assert False, "Server failed to stop" diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 537e4979..38ac1f67 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -130,7 +130,6 @@ async def stopping_coro(): assert subscription_id == test_subscription_id print("Sending unsubscribe reply") - counting_task.cancel() await ws.send( subscription_reply_template.format( subscription_id=subscription_id, @@ -138,6 +137,7 @@ async def stopping_coro(): query_id=query_id, ) ) + counting_task.cancel() stopping_task = asyncio.ensure_future(stopping_coro()) @@ -146,12 +146,13 @@ async def stopping_coro(): except asyncio.CancelledError: print("Now counting task is cancelled") - stopping_task.cancel() - + # Waiting for a clean stop try: - await stopping_task + await asyncio.wait_for(stopping_task, 3) except asyncio.CancelledError: print("Now stopping task is cancelled") + except asyncio.TimeoutError: + print("Now stopping task is in timeout") # await PhoenixChannelServerHelper.send_close(ws) except websockets.exceptions.ConnectionClosedOK: @@ -194,7 +195,7 @@ async def test_phoenix_channel_subscription( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name=test_channel, url=url + channel_name=test_channel, url=url, close_timeout=5 ) count = 10 @@ -212,6 +213,7 @@ async def test_phoenix_channel_subscription( # In more recent versions, 'break' will trigger __aexit__. if sys.version_info < (3, 7): await session._generator.aclose() + print("break") break count -= 1 From 7f685bc80ccfba6dd52c1b1d79fc080ea811ee31 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 01:28:30 +0200 Subject: [PATCH 15/24] Modifying exception to error log --- gql/transport/phoenix_channel_websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 94b7ada5..c0449b64 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -283,7 +283,7 @@ def _parse_answer( else: - raise ValueError( + log.error( "Unexpected listener_query_id for " f"subscriptionId={subscription_id}. " f"Expected={expected_id} " From 1f991d6ec18bd7717091b78e866c3e0a35e73f33 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 01:29:30 +0200 Subject: [PATCH 16/24] Faster phoenix heartbeat tests --- tests/test_phoenix_channel_subscription.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 38ac1f67..c1f2f8a2 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -317,7 +317,7 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name=test_channel, url=url, heartbeat_interval=1 + channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) From c57019a934b5f44612acbb05a82200868231b255 Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Mon, 16 Aug 2021 16:34:24 -0700 Subject: [PATCH 17/24] Make parse_answer more compliant with specs, better subscription management --- gql/transport/phoenix_channel_websockets.py | 160 +++++++++++++------- 1 file changed, 102 insertions(+), 58 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index fbd5e6c6..b9ba6442 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -113,7 +113,7 @@ async def _send_stop_message(self, query_id: int) -> None: The server should afterwards return a 'phx_reply' message with the same query_id and subscription_id of the 'unsubscribe' request. """ - subscription_id = self._find_subscription_id(query_id) + subscription_id = self._find_existing_subscription(query_id) unsubscribe_query_id = self.next_query_id self.next_query_id += 1 @@ -198,33 +198,62 @@ def _parse_answer( execution_result: Optional[ExecutionResult] = None subscription_id: Optional[str] = None - def _get_existing_subscription(d: dict) -> Subscription: - subscription_id = d.get("subscriptionId") - if subscription_id is not None: - subscription_id = str(subscription_id) - if subscription_id is None or subscription_id not in self.subscriptions: - raise ValueError("unregistered subscription id") - return self.subscriptions[subscription_id] + def _get_value(d: Any, key: str, label: str) -> Any: + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + return d.get(key) + + def _required_value(d: Any, key: str, label: str) -> Any: + value = _get_value(d, key, label) + if value is None: + raise ValueError(f"null {key} in {label}") + + return value + + def _required_subscription_id( + d: Any, label: str, must_exist: bool = False, must_not_exist=False + ) -> str: + subscription_id = str(_required_value(d, "subscriptionId", label)) + if must_exist and (subscription_id not in self.subscriptions): + raise ValueError("unregistered subscriptionId") + if must_not_exist and (subscription_id in self.subscriptions): + raise ValueError("previously registered subscriptionId") + + return subscription_id + + def _validate_response(d: Any, label: str) -> dict: + """Make sure query or mutation answer conforms. + The GraphQL spec says only three keys are permitted. + """ + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + keys = set(d.keys()) + invalid = keys - {"data", "errors", "extensions"} + if len(invalid) > 0: + raise ValueError( + f"{label} contains invalid items: " + ", ".join(invalid) + ) + return d try: json_answer = json.loads(answer) - event = str(json_answer.get("event")) + event = str(_required_value(json_answer, "event", "answer")) if event == "subscription:data": - payload = json_answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") + payload = _required_value(json_answer, "payload", "answer") - subscription = _get_existing_subscription(payload) + subscription_id = _required_subscription_id( + payload, "payload", must_exist=True + ) - result = payload.get("result") - if not isinstance(result, dict): - raise ValueError("result is not a dict") + result = _validate_response(payload.get("result"), "result") answer_type = "data" + subscription = self.subscriptions[subscription_id] answer_id = subscription.listener_id execution_result = ExecutionResult( @@ -235,34 +264,31 @@ def _get_existing_subscription(d: dict) -> Subscription: elif event == "phx_reply": try: - answer_id = int(json_answer.get("ref")) - except Exception: + answer_id = int(_required_value(json_answer, "ref", "answer")) + except ValueError: # pragma: no cover + raise + except Exception: # pragma: no cover raise ValueError("ref is not an integer") - - if answer_id == 0: + if answer_id == 0: # pragma: no cover raise ValueError("ref is zero") - payload = json_answer.get("payload") + payload = _required_value(json_answer, "payload", "answer") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - status = str(payload.get("status")) + status = _get_value(payload, "status", "payload") + if status is not None: + status = str(status) if status == "ok": - response = payload.get("response") - answer_type = "reply" if answer_id in self.listeners: - # Query, mutation or subscription answer - if not isinstance(response, dict): - raise ValueError("response is not a dict") - - subscription_id = response.get("subscriptionId") - if subscription_id is not None: + operation_type = self.listeners[answer_id].operation + if operation_type == "subscription": # Subscription answer - subscription_id = str(subscription_id) + response = _required_value(payload, "response", "payload") + subscription_id = _required_subscription_id( + response, "response", must_not_exist=True + ) self.subscriptions[subscription_id] = Subscription( answer_id ) @@ -272,13 +298,8 @@ def _get_existing_subscription(d: dict) -> Subscription: else: # Query or mutation answer # GraphQL spec says only three keys are permitted - keys = set(response.keys()) - invalid = keys - {"data", "errors", "extensions"} - if len(invalid) > 0: - raise ValueError( - "response contains invalid items: " - + ", ".join(invalid) - ) + response = _required_value(payload, "response", "payload") + response = _validate_response(response, "response") answer_type = "data" @@ -287,15 +308,23 @@ def _get_existing_subscription(d: dict) -> Subscription: errors=response.get("errors"), extensions=response.get("extensions"), ) + else: + ( + registered_subscription_id, + listener_id, + ) = self._find_subscription(answer_id) + if registered_subscription_id is not None: + # Unsubscription answer + response = _required_value(payload, "response", "payload") + subscription_id = _required_subscription_id( + response, "response" + ) + if subscription_id != registered_subscription_id: + raise ValueError("subscription id does not match") - elif isinstance(response, dict) and "subscriptionId" in response: - subscription = _get_existing_subscription(response) - if answer_id != subscription.unsubscribe_id: - raise ValueError("invalid unsubscribe answer id") - - answer_type = "unsubscribe" + answer_type = "unsubscribe" - answer_id = subscription.listener_id + answer_id = listener_id elif status == "error": response = payload.get("response") @@ -309,10 +338,13 @@ def _get_existing_subscription(d: dict) -> Subscription: raise TransportQueryError( str(response.get("reason")), query_id=answer_id ) - raise ValueError("reply error") + raise TransportQueryError("reply error", query_id=answer_id) elif status == "timeout": raise TransportQueryError("reply timeout", query_id=answer_id) + else: + # missing or unrecognized status, just continue + pass elif event == "phx_error": # Sent if the channel has crashed @@ -327,7 +359,7 @@ def _get_existing_subscription(d: dict) -> Subscription: except ValueError as e: log.error(f"Error parsing answer '{answer}' " + repr(e)) raise TransportProtocolError( - "Server did not return a GraphQL result" + "Server did not return a GraphQL result: " + str(e) ) from e return answer_type, answer_id, execution_result @@ -348,24 +380,36 @@ async def _handle_answer( await super()._handle_answer(answer_type, answer_id, execution_result) def _remove_listener(self, query_id) -> None: + """If the listener was a subscription, remove that information.""" try: - subscription_id = self._find_subscription_id(query_id) + subscription_id = self._find_existing_subscription(query_id) del self.subscriptions[subscription_id] except Exception: pass super()._remove_listener(query_id) - def _find_subscription_id(self, query_id: int) -> str: + def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]: """Perform a reverse lookup to find the subscription id matching a listener's query_id. """ for subscription_id, subscription in self.subscriptions.items(): - if subscription.listener_id == query_id: - return subscription_id + if query_id == subscription.listener_id: + return subscription_id, query_id + if query_id == subscription.unsubscribe_id: + return subscription_id, subscription.listener_id + return None, query_id + + def _find_existing_subscription(self, query_id: int) -> str: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. + """ + subscription_id, _listener_id = self._find_subscription(query_id) - raise TransportProtocolError( - f"No subscription registered for listener {query_id}" - ) # pragma: no cover + if subscription_id is None: + raise TransportProtocolError( + f"No subscription registered for listener {query_id}" + ) # pragma: no cover + return subscription_id async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: if self.heartbeat_task is not None: From e3788bc84353269ae96e022db9608138addb810a Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Mon, 16 Aug 2021 16:35:41 -0700 Subject: [PATCH 18/24] More testing, distinguish tests between queries and subscriptions --- tests/test_phoenix_channel_exceptions.py | 351 +++++++++++++++++++---- tests/test_phoenix_channel_query.py | 103 +++++-- 2 files changed, 388 insertions(+), 66 deletions(-) diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 6f066325..fb843c64 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -10,6 +10,19 @@ # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets +# 210, 219, 234, 239, 272-275, 277, 327 + + +def ensure_list(s): + return ( + s + if s is None or isinstance(s, list) + else list(s) + if isinstance(s, tuple) + else [s] + ) + + query1_str = """ query getContinents { continents { @@ -19,21 +32,48 @@ } """ -default_subscription_server_answer = ( +default_query_server_answer = ( '{"event":"phx_reply",' '"payload":' '{"response":' - '{"subscriptionId":"test_subscription"},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' '"status":"ok"},' '"ref":2,' '"topic":"test_topic"}' ) + +# other protocol exceptions + +reply_ref_null_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":null,' '"topic":"test_topic"}', +) + +reply_ref_zero_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":0,' '"topic":"test_topic"}', +) + + +# "status":"error" responses + +generic_error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + error_with_reason_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"reason":"internal error"},' + '{"response":{"reason":"internal error"},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -42,8 +82,7 @@ multiple_errors_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"errors": ["error 1", "error 2"]},' + '{"response":{"errors": ["error 1", "error 2"]},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -57,31 +96,95 @@ '"topic":"test_topic"}' ) +invalid_payload_data_answer = ( + '{"event":"phx_reply",' '"payload":"INVALID",' '"ref":2,' '"topic":"test_topic"}' +) -def server( - query_server_answer, subscription_server_answer=default_subscription_server_answer, -): +# "status":"ok" exceptions + +invalid_response_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":"INVALID",' + '"status":"ok"}' + '"ref":2,' + '"topic":"test_topic"}' +) + +invalid_response_keys_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":' + '{"data":{"continents":null},"invalid":null}",' + '"status":"ok"}' + '"ref":2,' + '"topic":"test_topic"}' +) + +invalid_event_server_answer = '{"event":"unknown"}' + + +def query_server(server_answers=default_query_server_answer): from .conftest import PhoenixChannelServerHelper async def phoenix_server(ws, path): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() - await ws.send(subscription_server_answer) - if query_server_answer is not None: - await ws.send(query_server_answer) + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() return phoenix_server +async def no_connection_ack_phoenix_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await ws.recv() + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(error_with_reason_server_answer), - server(multiple_errors_server_answer), - server(timeout_server_answer), + query_server(reply_ref_null_answer), + query_server(reply_ref_zero_answer), + query_server(invalid_payload_data_answer), + query_server(invalid_response_server_answer), + query_server(invalid_response_keys_server_answer), + no_connection_ack_phoenix_server, + query_server(invalid_event_server_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_protocol_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + query_server(generic_error_server_answer), + query_server(error_with_reason_server_answer), + query_server(multiple_errors_server_answer), + query_server(timeout_server_answer), ], indirect=True, ) @@ -104,71 +207,188 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): await session.execute(query) -invalid_subscription_id_server_answer = ( +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + +default_subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +missing_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +null_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":null},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +default_subscription_data_answer = ( '{"event":"subscription:data","payload":' - '{"subscriptionId":"INVALID","result":' + '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +default_subscription_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -invalid_payload_server_answer = ( +missing_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +null_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":null,"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"INVALID","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_payload_data_answer = ( '{"event":"subscription:data",' '"payload":"INVALID",' - '"ref":3,' - '"topic":"test_topic"}' + '"ref":null,' + '"topic":"test_subscription"}' ) -invalid_result_server_answer = ( +invalid_result_data_answer = ( '{"event":"subscription:data","payload":' - '{"subscriptionId":"test_subscription","result": "INVALID"},' - '"ref":3,' - '"topic":"test_topic"}' + '{"subscriptionId":"test_subscription","result":"INVALID"},' + '"ref":null,' + '"topic":"test_subscription"}' ) -generic_error_server_answer = ( +invalid_result_keys_data_answer = ( + '{"event":"subscription:data",' + '"payload":{"subscriptionId":"test_subscription",' + '"result":{"data":{"continents":null},"invalid":null}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_ref_answer = ( '{"event":"phx_reply",' - '"payload":' - '{"status":"error"},' - '"ref":2,' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":99,' '"topic":"test_topic"}' ) -protocol_server_answer = '{"event":"unknown"}' - -invalid_payload_subscription_server_answer = ( - '{"event":"phx_reply", "payload":"INVALID", "ref":2, "topic":"test_topic"}' +mismatched_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"no_such_subscription"},' + '"status":"ok"},' + '"ref":3,' + '"topic":"test_topic"}' ) -async def no_connection_ack_phoenix_server(ws, path): +def subscription_server( + server_answers=default_subscription_server_answer, + data_answers=default_subscription_data_answer, + unsubscribe_answers=default_subscription_unsubscribe_answer, +): from .conftest import PhoenixChannelServerHelper + import json - await ws.recv() - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + if server_answers is not None: + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) + if data_answers is not None: + for data_answer in ensure_list(data_answers): + await ws.send(data_answer) + if unsubscribe_answers is not None: + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "unsubscribe" + for unsubscribe_answer in ensure_list(unsubscribe_answers): + await ws.send(unsubscribe_answer) + else: + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + return phoenix_server @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(invalid_subscription_id_server_answer), - server(invalid_result_server_answer), - server(generic_error_server_answer), - no_connection_ack_phoenix_server, - server(protocol_server_answer), - server(invalid_payload_server_answer), - server(None, invalid_payload_subscription_server_answer), + subscription_server(invalid_subscription_ref_answer), + subscription_server(missing_subscription_id_server_answer), + subscription_server(null_subscription_id_server_answer), + subscription_server( + [default_subscription_server_answer, default_subscription_server_answer] + ), + subscription_server(data_answers=missing_subscription_id_data_answer), + subscription_server(data_answers=null_subscription_id_data_answer), + subscription_server(data_answers=invalid_subscription_id_data_answer), + subscription_server(data_answers=invalid_payload_data_answer), + subscription_server(data_answers=invalid_result_data_answer), + subscription_server(data_answers=invalid_result_keys_data_answer), ], indirect=True, ) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_protocol_error(event_loop, server, query_str): +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription_protocol_error( + event_loop, server, query_str +): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -183,17 +403,16 @@ async def test_phoenix_channel_protocol_error(event_loop, server, query_str): query = gql(query_str) with pytest.raises(TransportProtocolError): async with Client(transport=sample_transport) as session: - await session.execute(query) + async for _result in session.subscribe(query): + break -server_error_subscription_server_answer = ( - '{"event":"phx_error", "ref":2, "topic":"test_topic"}' -) +server_error_server_answer = '{"event":"phx_error", "ref":2, "topic":"test_topic"}' @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [server(None, server_error_subscription_server_answer)], indirect=True, + "server", [query_server(server_error_server_answer)], indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_phoenix_channel_server_error(event_loop, server, query_str): @@ -212,3 +431,35 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): with pytest.raises(TransportServerError): async with Client(transport=sample_transport) as session: await session.execute(query) + + +# These cannot be caught by the client +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + subscription_server(unsubscribe_answers=invalid_subscription_ref_answer), + subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + # Reduce close_timeout. These tests will wait for an unsubscribe + # reply that will never come... + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + break diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index c3679ac6..b13a8c55 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -14,6 +14,68 @@ } """ +default_query_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +@pytest.fixture +def ws_server_helper(request): + from .conftest import PhoenixChannelServerHelper + + yield PhoenixChannelServerHelper + + +async def query_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(default_query_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query(event_loop, server, query_str): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + result = await session.execute(query) + + print("Client received:", result) + + +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + subscription_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -24,7 +86,7 @@ '"topic":"test_topic"}' ) -query1_server_answer = ( +subscription_data_server_answer = ( '{"event":"subscription:data","payload":' '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' @@ -32,33 +94,39 @@ '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +unsubscribe_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -@pytest.fixture -def ws_server_helper(request): - from .conftest import PhoenixChannelServerHelper - - yield PhoenixChannelServerHelper - - -async def phoenix_server(ws, path): +async def subscription_server(ws, path): from .conftest import PhoenixChannelServerHelper await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(subscription_server_answer) - await ws.send(query1_server_answer) - await PhoenixChannelServerHelper.send_close(ws) + await ws.send(subscription_data_server_answer) + await ws.recv() + await ws.send(unsubscribe_server_answer) + # Unsubscribe will remove the listener + # await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() @pytest.mark.asyncio -@pytest.mark.parametrize("server", [phoenix_server], indirect=True) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_simple_query(event_loop, server, query_str): +@pytest.mark.parametrize("server", [subscription_server], indirect=True) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription(event_loop, server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -69,8 +137,11 @@ async def test_phoenix_channel_simple_query(event_loop, server, query_str): channel_name="test_channel", url=url ) + first_result = None query = gql(query_str) async with Client(transport=sample_transport) as session: - result = await session.execute(query) + async for result in session.subscribe(query): + first_result = result + break - print("Client received:", result) + print("Client received:", first_result) From 65c4a0cd7d54cc20b13b9769150fd34db11d86be Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Mon, 16 Aug 2021 19:07:11 -0700 Subject: [PATCH 19/24] remove reminder comment --- tests/test_phoenix_channel_exceptions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index fb843c64..41fa459a 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -10,8 +10,6 @@ # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets -# 210, 219, 234, 239, 272-275, 277, 327 - def ensure_list(s): return ( From 52fbf541b6e53a23f4401d961aeb5518f2e43c1f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 11:39:37 +0200 Subject: [PATCH 20/24] Adding failing test --- tests/test_phoenix_channel_subscription.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index c1f2f8a2..3c6ec2b2 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -221,6 +221,64 @@ async def test_phoenix_channel_subscription( assert count == end_count +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_phoenix_channel_subscription_no_break( + event_loop, server, subscription_str +): + import logging + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + async def testing_stopping_without_break(): + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name=test_channel, url=url, close_timeout=5 + ) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with Client(transport=sample_transport) as session: + async for result in session.subscribe(subscription): + number = result["countdown"]["number"] + print(f"Number received: {number}") + + # Simulate a slow consumer + await asyncio.sleep(0.1) + + if number == 9: + # When we consume the number 9 here in the async generator, + # all the 10 numbers have already been sent by the backend and + # are present in the listener queue + # we simulate here an unsubscribe message + # In that case, all the 10 numbers should be consumed in the + # generator and then the generator should be closed properly + await session.transport._send_stop_message(2) + + assert number == count + + count -= 1 + + assert count == -1 + + try: + await asyncio.wait_for(testing_stopping_without_break(), timeout=5) + except asyncio.TimeoutError: + assert False, "The async generator did not stop" + + heartbeat_data_template = ( "{{" '"topic":"{subscription_id}",' From 64b2f74835f4133bbb4b8630e5c382bb0469e32e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 12:06:04 +0200 Subject: [PATCH 21/24] Fix failing test by sending 'complete' instead of 'unsubscribe' The async generator will end properly and we will remove the listener inside the subscribe method --- gql/transport/phoenix_channel_websockets.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 5d9c9f8c..601bfddb 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -187,7 +187,7 @@ def _parse_answer( Returns a list consisting of: - the answer_type (between: - 'heartbeat', 'data', 'reply', 'error', 'unsubscribe') + 'data', 'reply', 'complete', 'close') - the answer id (Integer) if received or None - an execution Result if the answer_type is 'data' or None """ @@ -293,8 +293,6 @@ def _validate_response(d: Any, label: str) -> dict: answer_id ) - # Confirm unsubscribe for subscriptions - self.listeners[answer_id].send_stop = True else: # Query or mutation answer # GraphQL spec says only three keys are permitted @@ -322,7 +320,7 @@ def _validate_response(d: Any, label: str) -> dict: if subscription_id != registered_subscription_id: raise ValueError("subscription id does not match") - answer_type = "unsubscribe" + answer_type = "complete" answer_id = listener_id @@ -372,10 +370,6 @@ async def _handle_answer( ) -> None: if answer_type == "close": await self.close() - elif answer_type == "unsubscribe": - # Remove the listener here, to possibly signal - # that it is the last listener in the session. - self._remove_listener(answer_id) else: await super()._handle_answer(answer_type, answer_id, execution_result) From a2a22b36b734efc4110cfd19707e3fec68302927 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 14:59:40 +0200 Subject: [PATCH 22/24] Fix test coverage --- gql/transport/phoenix_channel_websockets.py | 19 +++----- tests/test_phoenix_channel_exceptions.py | 54 +++++++++++++++++++++ 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 601bfddb..91fdc76a 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -263,20 +263,14 @@ def _validate_response(d: Any, label: str) -> dict: ) elif event == "phx_reply": - try: - answer_id = int(_required_value(json_answer, "ref", "answer")) - except ValueError: # pragma: no cover - raise - except Exception: # pragma: no cover - raise ValueError("ref is not an integer") - if answer_id == 0: # pragma: no cover - raise ValueError("ref is zero") + + # Will generate a ValueError if 'ref' is not there + # or if it is not an integer + answer_id = int(_required_value(json_answer, "ref", "answer")) payload = _required_value(json_answer, "payload", "answer") status = _get_value(payload, "status", "payload") - if status is not None: - status = str(status) if status == "ok": answer_type = "reply" @@ -317,6 +311,7 @@ def _validate_response(d: Any, label: str) -> dict: subscription_id = _required_subscription_id( response, "response" ) + if subscription_id != registered_subscription_id: raise ValueError("subscription id does not match") @@ -357,7 +352,7 @@ def _validate_response(d: Any, label: str) -> dict: except ValueError as e: log.error(f"Error parsing answer '{answer}': {e!r}") raise TransportProtocolError( - "Server did not return a GraphQL result: " + str(e) + f"Server did not return a GraphQL result: {e!s}" ) from e return answer_type, answer_id, execution_result @@ -402,7 +397,7 @@ def _find_existing_subscription(self, query_id: int) -> str: if subscription_id is None: raise TransportProtocolError( f"No subscription registered for listener {query_id}" - ) # pragma: no cover + ) return subscription_id async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 41fa459a..1711d25a 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from gql import Client, gql @@ -7,6 +9,8 @@ TransportServerError, ) +from .conftest import MS + # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -223,6 +227,23 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): '"topic":"test_topic"}' ) +ref_is_not_an_integer_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":"not_an_integer",' + '"topic":"test_topic"}' +) + +missing_ref_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"topic":"test_topic"}' +) + missing_subscription_id_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -377,6 +398,8 @@ async def phoenix_server(ws, path): subscription_server(data_answers=missing_subscription_id_data_answer), subscription_server(data_answers=null_subscription_id_data_answer), subscription_server(data_answers=invalid_subscription_id_data_answer), + subscription_server(data_answers=ref_is_not_an_integer_server_answer), + subscription_server(data_answers=missing_ref_server_answer), subscription_server(data_answers=invalid_payload_data_answer), subscription_server(data_answers=invalid_result_data_answer), subscription_server(data_answers=invalid_result_keys_data_answer), @@ -402,6 +425,7 @@ async def test_phoenix_channel_subscription_protocol_error( with pytest.raises(TransportProtocolError): async with Client(transport=sample_transport) as session: async for _result in session.subscribe(query): + await asyncio.sleep(10 * MS) break @@ -461,3 +485,33 @@ async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): async with Client(transport=sample_transport) as session: async for _result in session.subscribe(query): break + + +# We can force the error if somehow the generator is still running while +# we receive a mismatched unsubscribe answer +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer)], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + await session.transport._send_stop_message(2) + await asyncio.sleep(10 * MS) From b740d58bc306fbcea8aa87ec22e97ac1003626c0 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 15:06:26 +0200 Subject: [PATCH 23/24] remove TransportClosed except in subscribe --- gql/transport/phoenix_channel_websockets.py | 2 +- gql/transport/websockets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 91fdc76a..4b59fda9 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -368,7 +368,7 @@ async def _handle_answer( else: await super()._handle_answer(answer_type, answer_id, execution_result) - def _remove_listener(self, query_id) -> None: + def _remove_listener(self, query_id: int) -> None: """If the listener was a subscription, remove that information.""" try: subscription_id = self._find_existing_subscription(query_id) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index a0086472..27c3965e 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -502,7 +502,7 @@ async def subscribe( ) break - except (asyncio.CancelledError, GeneratorExit, TransportClosed) as e: + except (asyncio.CancelledError, GeneratorExit) as e: log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: await self._send_stop_message(query_id) From 9fbaa829a9d3e44fb8565e08aa40a62a7cd18086 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 17:38:36 +0200 Subject: [PATCH 24/24] don't try to get the operation_type from the query in the transport --- gql/transport/phoenix_channel_websockets.py | 17 ++++++++------- gql/transport/websockets.py | 23 +++------------------ 2 files changed, 12 insertions(+), 28 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 4b59fda9..56d35f8b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -222,8 +222,8 @@ def _required_subscription_id( return subscription_id - def _validate_response(d: Any, label: str) -> dict: - """Make sure query or mutation answer conforms. + def _validate_data_response(d: Any, label: str) -> dict: + """Make sure query, mutation or subscription answer conforms. The GraphQL spec says only three keys are permitted. """ if not isinstance(d, dict): @@ -249,7 +249,7 @@ def _validate_response(d: Any, label: str) -> dict: payload, "payload", must_exist=True ) - result = _validate_response(payload.get("result"), "result") + result = _validate_data_response(payload.get("result"), "result") answer_type = "data" @@ -276,13 +276,15 @@ def _validate_response(d: Any, label: str) -> dict: answer_type = "reply" if answer_id in self.listeners: - operation_type = self.listeners[answer_id].operation - if operation_type == "subscription": + response = _required_value(payload, "response", "payload") + + if isinstance(response, dict) and "subscriptionId" in response: + # Subscription answer - response = _required_value(payload, "response", "payload") subscription_id = _required_subscription_id( response, "response", must_not_exist=True ) + self.subscriptions[subscription_id] = Subscription( answer_id ) @@ -290,8 +292,7 @@ def _validate_response(d: Any, label: str) -> dict: else: # Query or mutation answer # GraphQL spec says only three keys are permitted - response = _required_value(payload, "response", "payload") - response = _validate_response(response, "response") + response = _validate_data_response(response, "response") answer_type = "data" diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 27c3965e..50eeb6b0 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -6,8 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast import websockets -from graphql import DocumentNode, ExecutionResult, OperationType, print_ast -from graphql.utilities import get_operation_ast +from graphql import DocumentNode, ExecutionResult, print_ast from websockets.client import WebSocketClientProtocol from websockets.datastructures import HeadersLike from websockets.exceptions import ConnectionClosed @@ -35,17 +34,8 @@ class ListenerQueue: to the consumer once all the previous messages have been consumed from the queue """ - def __init__( - self, query_id: int, operation_type: Optional[OperationType], send_stop: bool - ) -> None: - operation: str = "query" - if operation_type is not None: - if operation_type == OperationType.MUTATION: # pragma: no cover - operation = "mutation" - elif operation_type == OperationType.SUBSCRIPTION: - operation = "subscription" + def __init__(self, query_id: int, send_stop: bool) -> None: self.query_id: int = query_id - self.operation: str = operation self.send_stop: bool = send_stop self._queue: asyncio.Queue = asyncio.Queue() self._closed: bool = False @@ -468,14 +458,7 @@ async def subscribe( ) # Create a queue to receive the answers for this query_id - operation_type = None - operation = get_operation_ast(document, operation_name) - if operation is not None: - operation_type = operation.operation - - listener = ListenerQueue( - query_id, operation_type, send_stop=(send_stop is True) - ) + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) self.listeners[query_id] = listener # We will need to wait at close for this query to clean properly