diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 27e58f2a..26ab159b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -34,6 +34,7 @@ def __init__( """ self.channel_name = channel_name self.heartbeat_interval = heartbeat_interval + self.heartbeat_task: Optional[asyncio.Future] = None self.subscription_ids_to_query_ids: Dict[str, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 7e26f31c..2dce0545 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -175,8 +175,9 @@ async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - # We should always have an active websocket connection here - assert self.websocket is not None + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") # Wait for the next websocket frame. Can raise ConnectionClosed data: Data = await self.websocket.recv() @@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None: except (ConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break + except TransportClosed: + break # Parse the answer try: diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index ef46db47..16ce5c34 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -101,9 +101,13 @@ async def stopping_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_phoenix_channel_subscription(event_loop, server, subscription_str): + import logging from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}"