Skip to content

Commit 20ae2e2

Browse files
authored
Handle Absinthe unsubscriptions (#228)
1 parent 66174a6 commit 20ae2e2

6 files changed

+918
-165
lines changed

gql/transport/phoenix_channel_websockets.py

Lines changed: 190 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import logging
34
from typing import Any, Dict, Optional, Tuple
45

56
from graphql import DocumentNode, ExecutionResult, print_ast
@@ -12,6 +13,16 @@
1213
)
1314
from .websockets import WebsocketsTransport
1415

16+
log = logging.getLogger(__name__)
17+
18+
19+
class Subscription:
20+
"""Records listener_id and unsubscribe query_id for a subscription."""
21+
22+
def __init__(self, query_id: int) -> None:
23+
self.listener_id: int = query_id
24+
self.unsubscribe_id: Optional[int] = None
25+
1526

1627
class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
1728
"""The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport
@@ -24,17 +35,23 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
2435
"""
2536

2637
def __init__(
27-
self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs
38+
self,
39+
channel_name: str = "__absinthe__:control",
40+
heartbeat_interval: float = 30,
41+
*args,
42+
**kwargs,
2843
) -> None:
2944
"""Initialize the transport with the given parameters.
3045
31-
:param channel_name: Channel on the server this transport will join
46+
:param channel_name: Channel on the server this transport will join.
47+
The default for Absinthe servers is "__absinthe__:control"
3248
:param heartbeat_interval: Interval in second between each heartbeat messages
3349
sent by the client
3450
"""
35-
self.channel_name = channel_name
36-
self.heartbeat_interval = heartbeat_interval
37-
self.subscription_ids_to_query_ids: Dict[str, int] = {}
51+
self.channel_name: str = channel_name
52+
self.heartbeat_interval: float = heartbeat_interval
53+
self.heartbeat_task: Optional[asyncio.Future] = None
54+
self.subscriptions: Dict[str, Subscription] = {}
3855
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
3956

4057
async def _send_init_message_and_wait_ack(self) -> None:
@@ -90,14 +107,32 @@ async def heartbeat_coro():
90107
self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())
91108

92109
async def _send_stop_message(self, query_id: int) -> None:
93-
try:
94-
await self.listeners[query_id].put(("complete", None))
95-
except KeyError: # pragma: no cover
96-
pass
110+
"""Send an 'unsubscribe' message to the Phoenix Channel referencing
111+
the listener's query_id, saving the query_id of the message.
97112
98-
async def _send_connection_terminate_message(self) -> None:
99-
"""Send a phx_leave message to disconnect from the provided channel.
113+
The server should afterwards return a 'phx_reply' message with
114+
the same query_id and subscription_id of the 'unsubscribe' request.
100115
"""
116+
subscription_id = self._find_existing_subscription(query_id)
117+
118+
unsubscribe_query_id = self.next_query_id
119+
self.next_query_id += 1
120+
121+
# Save the ref so it can be matched in the reply
122+
self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id
123+
unsubscribe_message = json.dumps(
124+
{
125+
"topic": self.channel_name,
126+
"event": "unsubscribe",
127+
"payload": {"subscriptionId": subscription_id},
128+
"ref": unsubscribe_query_id,
129+
}
130+
)
131+
132+
await self._send(unsubscribe_message)
133+
134+
async def _send_connection_terminate_message(self) -> None:
135+
"""Send a phx_leave message to disconnect from the provided channel."""
101136

102137
query_id = self.next_query_id
103138
self.next_query_id += 1
@@ -152,7 +187,7 @@ def _parse_answer(
152187
153188
Returns a list consisting of:
154189
- the answer_type (between:
155-
'heartbeat', 'data', 'reply', 'error', 'close')
190+
'data', 'reply', 'complete', 'close')
156191
- the answer id (Integer) if received or None
157192
- an execution Result if the answer_type is 'data' or None
158193
"""
@@ -161,56 +196,129 @@ def _parse_answer(
161196
answer_id: Optional[int] = None
162197
answer_type: str = ""
163198
execution_result: Optional[ExecutionResult] = None
199+
subscription_id: Optional[str] = None
200+
201+
def _get_value(d: Any, key: str, label: str) -> Any:
202+
if not isinstance(d, dict):
203+
raise ValueError(f"{label} is not a dict")
204+
205+
return d.get(key)
206+
207+
def _required_value(d: Any, key: str, label: str) -> Any:
208+
value = _get_value(d, key, label)
209+
if value is None:
210+
raise ValueError(f"null {key} in {label}")
211+
212+
return value
213+
214+
def _required_subscription_id(
215+
d: Any, label: str, must_exist: bool = False, must_not_exist=False
216+
) -> str:
217+
subscription_id = str(_required_value(d, "subscriptionId", label))
218+
if must_exist and (subscription_id not in self.subscriptions):
219+
raise ValueError("unregistered subscriptionId")
220+
if must_not_exist and (subscription_id in self.subscriptions):
221+
raise ValueError("previously registered subscriptionId")
222+
223+
return subscription_id
224+
225+
def _validate_data_response(d: Any, label: str) -> dict:
226+
"""Make sure query, mutation or subscription answer conforms.
227+
The GraphQL spec says only three keys are permitted.
228+
"""
229+
if not isinstance(d, dict):
230+
raise ValueError(f"{label} is not a dict")
231+
232+
keys = set(d.keys())
233+
invalid = keys - {"data", "errors", "extensions"}
234+
if len(invalid) > 0:
235+
raise ValueError(
236+
f"{label} contains invalid items: " + ", ".join(invalid)
237+
)
238+
return d
164239

165240
try:
166241
json_answer = json.loads(answer)
167242

168-
event = str(json_answer.get("event"))
243+
event = str(_required_value(json_answer, "event", "answer"))
169244

170245
if event == "subscription:data":
171-
payload = json_answer.get("payload")
246+
payload = _required_value(json_answer, "payload", "answer")
172247

173-
if not isinstance(payload, dict):
174-
raise ValueError("payload is not a dict")
175-
176-
subscription_id = str(payload.get("subscriptionId"))
177-
try:
178-
answer_id = self.subscription_ids_to_query_ids[subscription_id]
179-
except KeyError:
180-
raise ValueError(
181-
f"subscription '{subscription_id}' has not been registerd"
182-
)
183-
184-
result = payload.get("result")
248+
subscription_id = _required_subscription_id(
249+
payload, "payload", must_exist=True
250+
)
185251

186-
if not isinstance(result, dict):
187-
raise ValueError("result is not a dict")
252+
result = _validate_data_response(payload.get("result"), "result")
188253

189254
answer_type = "data"
190255

256+
subscription = self.subscriptions[subscription_id]
257+
answer_id = subscription.listener_id
258+
191259
execution_result = ExecutionResult(
192-
errors=payload.get("errors"),
193260
data=result.get("data"),
194-
extensions=payload.get("extensions"),
261+
errors=result.get("errors"),
262+
extensions=result.get("extensions"),
195263
)
196264

197265
elif event == "phx_reply":
198-
answer_id = int(json_answer.get("ref"))
199-
payload = json_answer.get("payload")
200266

201-
if not isinstance(payload, dict):
202-
raise ValueError("payload is not a dict")
267+
# Will generate a ValueError if 'ref' is not there
268+
# or if it is not an integer
269+
answer_id = int(_required_value(json_answer, "ref", "answer"))
203270

204-
status = str(payload.get("status"))
271+
payload = _required_value(json_answer, "payload", "answer")
205272

206-
if status == "ok":
273+
status = _get_value(payload, "status", "payload")
207274

275+
if status == "ok":
208276
answer_type = "reply"
209-
response = payload.get("response")
210277

211-
if isinstance(response, dict) and "subscriptionId" in response:
212-
subscription_id = str(response.get("subscriptionId"))
213-
self.subscription_ids_to_query_ids[subscription_id] = answer_id
278+
if answer_id in self.listeners:
279+
response = _required_value(payload, "response", "payload")
280+
281+
if isinstance(response, dict) and "subscriptionId" in response:
282+
283+
# Subscription answer
284+
subscription_id = _required_subscription_id(
285+
response, "response", must_not_exist=True
286+
)
287+
288+
self.subscriptions[subscription_id] = Subscription(
289+
answer_id
290+
)
291+
292+
else:
293+
# Query or mutation answer
294+
# GraphQL spec says only three keys are permitted
295+
response = _validate_data_response(response, "response")
296+
297+
answer_type = "data"
298+
299+
execution_result = ExecutionResult(
300+
data=response.get("data"),
301+
errors=response.get("errors"),
302+
extensions=response.get("extensions"),
303+
)
304+
else:
305+
(
306+
registered_subscription_id,
307+
listener_id,
308+
) = self._find_subscription(answer_id)
309+
if registered_subscription_id is not None:
310+
# Unsubscription answer
311+
response = _required_value(payload, "response", "payload")
312+
subscription_id = _required_subscription_id(
313+
response, "response"
314+
)
315+
316+
if subscription_id != registered_subscription_id:
317+
raise ValueError("subscription id does not match")
318+
319+
answer_type = "complete"
320+
321+
answer_id = listener_id
214322

215323
elif status == "error":
216324
response = payload.get("response")
@@ -224,21 +332,28 @@ def _parse_answer(
224332
raise TransportQueryError(
225333
str(response.get("reason")), query_id=answer_id
226334
)
227-
raise ValueError("reply error")
335+
raise TransportQueryError("reply error", query_id=answer_id)
228336

229337
elif status == "timeout":
230338
raise TransportQueryError("reply timeout", query_id=answer_id)
339+
else:
340+
# missing or unrecognized status, just continue
341+
pass
231342

232343
elif event == "phx_error":
344+
# Sent if the channel has crashed
345+
# answer_id will be the "join_ref" for the channel
346+
# answer_id = int(json_answer.get("ref"))
233347
raise TransportServerError("Server error")
234348
elif event == "phx_close":
235349
answer_type = "close"
236350
else:
237-
raise ValueError
351+
raise ValueError("unrecognized event")
238352

239353
except ValueError as e:
354+
log.error(f"Error parsing answer '{answer}': {e!r}")
240355
raise TransportProtocolError(
241-
"Server did not return a GraphQL result"
356+
f"Server did not return a GraphQL result: {e!s}"
242357
) from e
243358

244359
return answer_type, answer_id, execution_result
@@ -254,6 +369,38 @@ async def _handle_answer(
254369
else:
255370
await super()._handle_answer(answer_type, answer_id, execution_result)
256371

372+
def _remove_listener(self, query_id: int) -> None:
373+
"""If the listener was a subscription, remove that information."""
374+
try:
375+
subscription_id = self._find_existing_subscription(query_id)
376+
del self.subscriptions[subscription_id]
377+
except Exception:
378+
pass
379+
super()._remove_listener(query_id)
380+
381+
def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]:
382+
"""Perform a reverse lookup to find the subscription id matching
383+
a listener's query_id.
384+
"""
385+
for subscription_id, subscription in self.subscriptions.items():
386+
if query_id == subscription.listener_id:
387+
return subscription_id, query_id
388+
if query_id == subscription.unsubscribe_id:
389+
return subscription_id, subscription.listener_id
390+
return None, query_id
391+
392+
def _find_existing_subscription(self, query_id: int) -> str:
393+
"""Perform a reverse lookup to find the subscription id matching
394+
a listener's query_id.
395+
"""
396+
subscription_id, _listener_id = self._find_subscription(query_id)
397+
398+
if subscription_id is None:
399+
raise TransportProtocolError(
400+
f"No subscription registered for listener {query_id}"
401+
)
402+
return subscription_id
403+
257404
async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
258405
if self.heartbeat_task is not None:
259406
self.heartbeat_task.cancel()

gql/transport/websockets.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ async def _receive(self) -> str:
175175
"""Wait the next message from the websocket connection and log the answer
176176
"""
177177

178-
# We should always have an active websocket connection here
179-
assert self.websocket is not None
178+
# It is possible that the websocket has been already closed in another task
179+
if self.websocket is None:
180+
raise TransportClosed("Transport is already closed")
180181

181182
# Wait for the next websocket frame. Can raise ConnectionClosed
182183
data: Data = await self.websocket.recv()
@@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None:
387388
except (ConnectionClosed, TransportProtocolError) as e:
388389
await self._fail(e, clean_close=False)
389390
break
391+
except TransportClosed:
392+
break
390393

391394
# Parse the answer
392395
try:
@@ -483,15 +486,14 @@ async def subscribe(
483486
break
484487

485488
except (asyncio.CancelledError, GeneratorExit) as e:
486-
log.debug("Exception in subscribe: " + repr(e))
489+
log.debug(f"Exception in subscribe: {e!r}")
487490
if listener.send_stop:
488491
await self._send_stop_message(query_id)
489492
listener.send_stop = False
490493

491494
finally:
492-
del self.listeners[query_id]
493-
if len(self.listeners) == 0:
494-
self._no_more_listeners.set()
495+
log.debug(f"In subscribe finally for query_id {query_id}")
496+
self._remove_listener(query_id)
495497

496498
async def execute(
497499
self,
@@ -609,6 +611,19 @@ async def connect(self) -> None:
609611

610612
log.debug("connect: done")
611613

614+
def _remove_listener(self, query_id) -> None:
615+
"""After exiting from a subscription, remove the listener and
616+
signal an event if this was the last listener for the client.
617+
"""
618+
if query_id in self.listeners:
619+
del self.listeners[query_id]
620+
621+
remaining = len(self.listeners)
622+
log.debug(f"listener {query_id} deleted, {remaining} remaining")
623+
624+
if remaining == 0:
625+
self._no_more_listeners.set()
626+
612627
async def _clean_close(self, e: Exception) -> None:
613628
"""Coroutine which will:
614629
@@ -627,7 +642,7 @@ async def _clean_close(self, e: Exception) -> None:
627642
try:
628643
await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout)
629644
except asyncio.TimeoutError: # pragma: no cover
630-
pass
645+
log.debug("Timer close_timeout fired")
631646

632647
# Finally send the 'connection_terminate' message
633648
await self._send_connection_terminate_message()

0 commit comments

Comments
 (0)