Skip to content

Handle Absinthe unsubscriptions #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c7d02cd
Initialize heartbeat_task correctly
leszekhanusz Aug 10, 2021
fdb8307
Add debug logs to phoenix subscription test
leszekhanusz Aug 10, 2021
f9e74e9
Fix assertion error when transport is already closed when _receive is…
leszekhanusz Aug 10, 2021
1e847d5
Stop websocket subscriptions on TransportClosed
pzingg Aug 13, 2021
caa671d
Add logger to PhoenixChannelWebsocketsTransport
pzingg Aug 13, 2021
3b3cec7
Implement Absinthe unsubscribe protocol in PhoenixChannelWebsocketsTr…
pzingg Aug 13, 2021
4ec38ff
Add unsubscribe to Phoenix Channel subscription tests
pzingg Aug 13, 2021
018ab64
Fix typo in unused template
pzingg Aug 13, 2021
34c3eb4
Add generator aclose in Phoenix Channel tests for Python<3.7
pzingg Aug 13, 2021
01f4a45
Fix phoenix parse_answer and subscriptions
pzingg Aug 16, 2021
fd0b32f
parse responses per graphql spec
pzingg Aug 16, 2021
ca242a9
Refactor unsubscription reply
leszekhanusz Aug 16, 2021
7c2f85e
Note the operation type for each listener
pzingg Aug 16, 2021
f35de03
Make phoenix behavior more similar to the websockets code
leszekhanusz Aug 16, 2021
7f685bc
Modifying exception to error log
leszekhanusz Aug 16, 2021
1f991d6
Faster phoenix heartbeat tests
leszekhanusz Aug 16, 2021
c57019a
Make parse_answer more compliant with specs, better subscription mana…
pzingg Aug 16, 2021
e3788bc
More testing, distinguish tests between queries and subscriptions
pzingg Aug 16, 2021
e922865
Merge branch 'master' into fix_phoenix_subscriptions_close
leszekhanusz Aug 16, 2021
e0f009f
merge from e922865
pzingg Aug 16, 2021
8a65b42
Merge branch 'fix_phoenix_subscriptions_close' of github.com:pzingg/g…
pzingg Aug 16, 2021
65c4a0c
remove reminder comment
pzingg Aug 17, 2021
52fbf54
Adding failing test
leszekhanusz Aug 17, 2021
64b2f74
Fix failing test by sending 'complete' instead of 'unsubscribe'
leszekhanusz Aug 17, 2021
a2a22b3
Fix test coverage
leszekhanusz Aug 17, 2021
b740d58
remove TransportClosed except in subscribe
leszekhanusz Aug 17, 2021
9fbaa82
don't try to get the operation_type from the query in the transport
leszekhanusz Aug 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 190 additions & 43 deletions gql/transport/phoenix_channel_websockets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import logging
from typing import Any, Dict, Optional, Tuple

from graphql import DocumentNode, ExecutionResult, print_ast
Expand All @@ -12,6 +13,16 @@
)
from .websockets import WebsocketsTransport

log = logging.getLogger(__name__)


class Subscription:
"""Records listener_id and unsubscribe query_id for a subscription."""

def __init__(self, query_id: int) -> None:
self.listener_id: int = query_id
self.unsubscribe_id: Optional[int] = None


class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
"""The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport
Expand All @@ -24,17 +35,23 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
"""

def __init__(
self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs
self,
channel_name: str = "__absinthe__:control",
heartbeat_interval: float = 30,
*args,
**kwargs,
) -> None:
"""Initialize the transport with the given parameters.

:param channel_name: Channel on the server this transport will join
:param channel_name: Channel on the server this transport will join.
The default for Absinthe servers is "__absinthe__:control"
:param heartbeat_interval: Interval in second between each heartbeat messages
sent by the client
"""
self.channel_name = channel_name
self.heartbeat_interval = heartbeat_interval
self.subscription_ids_to_query_ids: Dict[str, int] = {}
self.channel_name: str = channel_name
self.heartbeat_interval: float = heartbeat_interval
self.heartbeat_task: Optional[asyncio.Future] = None
self.subscriptions: Dict[str, Subscription] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)

async def _send_init_message_and_wait_ack(self) -> None:
Expand Down Expand Up @@ -90,14 +107,32 @@ async def heartbeat_coro():
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

async def _send_stop_message(self, query_id: int) -> None:
try:
await self.listeners[query_id].put(("complete", None))
except KeyError: # pragma: no cover
pass
"""Send an 'unsubscribe' message to the Phoenix Channel referencing
the listener's query_id, saving the query_id of the message.

async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel.
The server should afterwards return a 'phx_reply' message with
the same query_id and subscription_id of the 'unsubscribe' request.
"""
subscription_id = self._find_existing_subscription(query_id)

unsubscribe_query_id = self.next_query_id
self.next_query_id += 1

# Save the ref so it can be matched in the reply
self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id
unsubscribe_message = json.dumps(
{
"topic": self.channel_name,
"event": "unsubscribe",
"payload": {"subscriptionId": subscription_id},
"ref": unsubscribe_query_id,
}
)

await self._send(unsubscribe_message)

async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel."""

query_id = self.next_query_id
self.next_query_id += 1
Expand Down Expand Up @@ -152,7 +187,7 @@ def _parse_answer(

Returns a list consisting of:
- the answer_type (between:
'heartbeat', 'data', 'reply', 'error', 'close')
'data', 'reply', 'complete', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""
Expand All @@ -161,56 +196,129 @@ def _parse_answer(
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None
subscription_id: Optional[str] = None

def _get_value(d: Any, key: str, label: str) -> Any:
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")

return d.get(key)

def _required_value(d: Any, key: str, label: str) -> Any:
value = _get_value(d, key, label)
if value is None:
raise ValueError(f"null {key} in {label}")

return value

def _required_subscription_id(
d: Any, label: str, must_exist: bool = False, must_not_exist=False
) -> str:
subscription_id = str(_required_value(d, "subscriptionId", label))
if must_exist and (subscription_id not in self.subscriptions):
raise ValueError("unregistered subscriptionId")
if must_not_exist and (subscription_id in self.subscriptions):
raise ValueError("previously registered subscriptionId")

return subscription_id

def _validate_data_response(d: Any, label: str) -> dict:
"""Make sure query, mutation or subscription answer conforms.
The GraphQL spec says only three keys are permitted.
"""
if not isinstance(d, dict):
raise ValueError(f"{label} is not a dict")

keys = set(d.keys())
invalid = keys - {"data", "errors", "extensions"}
if len(invalid) > 0:
raise ValueError(
f"{label} contains invalid items: " + ", ".join(invalid)
)
return d

try:
json_answer = json.loads(answer)

event = str(json_answer.get("event"))
event = str(_required_value(json_answer, "event", "answer"))

if event == "subscription:data":
payload = json_answer.get("payload")
payload = _required_value(json_answer, "payload", "answer")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

subscription_id = str(payload.get("subscriptionId"))
try:
answer_id = self.subscription_ids_to_query_ids[subscription_id]
except KeyError:
raise ValueError(
f"subscription '{subscription_id}' has not been registerd"
)

result = payload.get("result")
subscription_id = _required_subscription_id(
payload, "payload", must_exist=True
)

if not isinstance(result, dict):
raise ValueError("result is not a dict")
result = _validate_data_response(payload.get("result"), "result")

answer_type = "data"

subscription = self.subscriptions[subscription_id]
answer_id = subscription.listener_id

execution_result = ExecutionResult(
errors=payload.get("errors"),
data=result.get("data"),
extensions=payload.get("extensions"),
errors=result.get("errors"),
extensions=result.get("extensions"),
)

elif event == "phx_reply":
answer_id = int(json_answer.get("ref"))
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")
# Will generate a ValueError if 'ref' is not there
# or if it is not an integer
answer_id = int(_required_value(json_answer, "ref", "answer"))

status = str(payload.get("status"))
payload = _required_value(json_answer, "payload", "answer")

if status == "ok":
status = _get_value(payload, "status", "payload")

if status == "ok":
answer_type = "reply"
response = payload.get("response")

if isinstance(response, dict) and "subscriptionId" in response:
subscription_id = str(response.get("subscriptionId"))
self.subscription_ids_to_query_ids[subscription_id] = answer_id
if answer_id in self.listeners:
response = _required_value(payload, "response", "payload")

if isinstance(response, dict) and "subscriptionId" in response:

# Subscription answer
subscription_id = _required_subscription_id(
response, "response", must_not_exist=True
)

self.subscriptions[subscription_id] = Subscription(
answer_id
)

else:
# Query or mutation answer
# GraphQL spec says only three keys are permitted
response = _validate_data_response(response, "response")

answer_type = "data"

execution_result = ExecutionResult(
data=response.get("data"),
errors=response.get("errors"),
extensions=response.get("extensions"),
)
else:
(
registered_subscription_id,
listener_id,
) = self._find_subscription(answer_id)
if registered_subscription_id is not None:
# Unsubscription answer
response = _required_value(payload, "response", "payload")
subscription_id = _required_subscription_id(
response, "response"
)

if subscription_id != registered_subscription_id:
raise ValueError("subscription id does not match")

answer_type = "complete"

answer_id = listener_id

elif status == "error":
response = payload.get("response")
Expand All @@ -224,21 +332,28 @@ def _parse_answer(
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise ValueError("reply error")
raise TransportQueryError("reply error", query_id=answer_id)

elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)
else:
# missing or unrecognized status, just continue
pass

elif event == "phx_error":
# Sent if the channel has crashed
# answer_id will be the "join_ref" for the channel
# answer_id = int(json_answer.get("ref"))
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError
raise ValueError("unrecognized event")

except ValueError as e:
log.error(f"Error parsing answer '{answer}': {e!r}")
raise TransportProtocolError(
"Server did not return a GraphQL result"
f"Server did not return a GraphQL result: {e!s}"
) from e

return answer_type, answer_id, execution_result
Expand All @@ -254,6 +369,38 @@ async def _handle_answer(
else:
await super()._handle_answer(answer_type, answer_id, execution_result)

def _remove_listener(self, query_id: int) -> None:
"""If the listener was a subscription, remove that information."""
try:
subscription_id = self._find_existing_subscription(query_id)
del self.subscriptions[subscription_id]
except Exception:
pass
super()._remove_listener(query_id)

def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
for subscription_id, subscription in self.subscriptions.items():
if query_id == subscription.listener_id:
return subscription_id, query_id
if query_id == subscription.unsubscribe_id:
return subscription_id, subscription.listener_id
return None, query_id

def _find_existing_subscription(self, query_id: int) -> str:
"""Perform a reverse lookup to find the subscription id matching
a listener's query_id.
"""
subscription_id, _listener_id = self._find_subscription(query_id)

if subscription_id is None:
raise TransportProtocolError(
f"No subscription registered for listener {query_id}"
)
return subscription_id

async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()
Expand Down
29 changes: 22 additions & 7 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ async def _receive(self) -> str:
"""Wait the next message from the websocket connection and log the answer
"""

# We should always have an active websocket connection here
assert self.websocket is not None
# It is possible that the websocket has been already closed in another task
if self.websocket is None:
raise TransportClosed("Transport is already closed")

# Wait for the next websocket frame. Can raise ConnectionClosed
data: Data = await self.websocket.recv()
Expand Down Expand Up @@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None:
except (ConnectionClosed, TransportProtocolError) as e:
await self._fail(e, clean_close=False)
break
except TransportClosed:
break

# Parse the answer
try:
Expand Down Expand Up @@ -483,15 +486,14 @@ async def subscribe(
break

except (asyncio.CancelledError, GeneratorExit) as e:
log.debug("Exception in subscribe: " + repr(e))
log.debug(f"Exception in subscribe: {e!r}")
if listener.send_stop:
await self._send_stop_message(query_id)
listener.send_stop = False

finally:
del self.listeners[query_id]
if len(self.listeners) == 0:
self._no_more_listeners.set()
log.debug(f"In subscribe finally for query_id {query_id}")
self._remove_listener(query_id)

async def execute(
self,
Expand Down Expand Up @@ -609,6 +611,19 @@ async def connect(self) -> None:

log.debug("connect: done")

def _remove_listener(self, query_id) -> None:
"""After exiting from a subscription, remove the listener and
signal an event if this was the last listener for the client.
"""
if query_id in self.listeners:
del self.listeners[query_id]

remaining = len(self.listeners)
log.debug(f"listener {query_id} deleted, {remaining} remaining")

if remaining == 0:
self._no_more_listeners.set()

async def _clean_close(self, e: Exception) -> None:
"""Coroutine which will:

Expand All @@ -627,7 +642,7 @@ async def _clean_close(self, e: Exception) -> None:
try:
await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout)
except asyncio.TimeoutError: # pragma: no cover
pass
log.debug("Timer close_timeout fired")

# Finally send the 'connection_terminate' message
await self._send_connection_terminate_message()
Expand Down
Loading