1
1
import asyncio
2
2
import json
3
+ import logging
3
4
from typing import Any , Dict , Optional , Tuple
4
5
5
6
from graphql import DocumentNode , ExecutionResult , print_ast
12
13
)
13
14
from .websockets import WebsocketsTransport
14
15
16
+ log = logging .getLogger (__name__ )
17
+
18
+
19
+ class Subscription :
20
+ """Records listener_id and unsubscribe query_id for a subscription."""
21
+
22
+ def __init__ (self , query_id : int ) -> None :
23
+ self .listener_id : int = query_id
24
+ self .unsubscribe_id : Optional [int ] = None
25
+
15
26
16
27
class PhoenixChannelWebsocketsTransport (WebsocketsTransport ):
17
28
"""The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport
@@ -24,17 +35,23 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
24
35
"""
25
36
26
37
def __init__ (
27
- self , channel_name : str , heartbeat_interval : float = 30 , * args , ** kwargs
38
+ self ,
39
+ channel_name : str = "__absinthe__:control" ,
40
+ heartbeat_interval : float = 30 ,
41
+ * args ,
42
+ ** kwargs ,
28
43
) -> None :
29
44
"""Initialize the transport with the given parameters.
30
45
31
- :param channel_name: Channel on the server this transport will join
46
+ :param channel_name: Channel on the server this transport will join.
47
+ The default for Absinthe servers is "__absinthe__:control"
32
48
:param heartbeat_interval: Interval in second between each heartbeat messages
33
49
sent by the client
34
50
"""
35
- self .channel_name = channel_name
36
- self .heartbeat_interval = heartbeat_interval
37
- self .subscription_ids_to_query_ids : Dict [str , int ] = {}
51
+ self .channel_name : str = channel_name
52
+ self .heartbeat_interval : float = heartbeat_interval
53
+ self .heartbeat_task : Optional [asyncio .Future ] = None
54
+ self .subscriptions : Dict [str , Subscription ] = {}
38
55
super (PhoenixChannelWebsocketsTransport , self ).__init__ (* args , ** kwargs )
39
56
40
57
async def _send_init_message_and_wait_ack (self ) -> None :
@@ -90,14 +107,32 @@ async def heartbeat_coro():
90
107
self .heartbeat_task = asyncio .ensure_future (heartbeat_coro ())
91
108
92
109
async def _send_stop_message (self , query_id : int ) -> None :
93
- try :
94
- await self .listeners [query_id ].put (("complete" , None ))
95
- except KeyError : # pragma: no cover
96
- pass
110
+ """Send an 'unsubscribe' message to the Phoenix Channel referencing
111
+ the listener's query_id, saving the query_id of the message.
97
112
98
- async def _send_connection_terminate_message ( self ) -> None :
99
- """Send a phx_leave message to disconnect from the provided channel .
113
+ The server should afterwards return a 'phx_reply' message with
114
+ the same query_id and subscription_id of the 'unsubscribe' request .
100
115
"""
116
+ subscription_id = self ._find_existing_subscription (query_id )
117
+
118
+ unsubscribe_query_id = self .next_query_id
119
+ self .next_query_id += 1
120
+
121
+ # Save the ref so it can be matched in the reply
122
+ self .subscriptions [subscription_id ].unsubscribe_id = unsubscribe_query_id
123
+ unsubscribe_message = json .dumps (
124
+ {
125
+ "topic" : self .channel_name ,
126
+ "event" : "unsubscribe" ,
127
+ "payload" : {"subscriptionId" : subscription_id },
128
+ "ref" : unsubscribe_query_id ,
129
+ }
130
+ )
131
+
132
+ await self ._send (unsubscribe_message )
133
+
134
+ async def _send_connection_terminate_message (self ) -> None :
135
+ """Send a phx_leave message to disconnect from the provided channel."""
101
136
102
137
query_id = self .next_query_id
103
138
self .next_query_id += 1
@@ -152,7 +187,7 @@ def _parse_answer(
152
187
153
188
Returns a list consisting of:
154
189
- the answer_type (between:
155
- 'heartbeat', ' data', 'reply', 'error ', 'close')
190
+ 'data', 'reply', 'complete ', 'close')
156
191
- the answer id (Integer) if received or None
157
192
- an execution Result if the answer_type is 'data' or None
158
193
"""
@@ -161,56 +196,129 @@ def _parse_answer(
161
196
answer_id : Optional [int ] = None
162
197
answer_type : str = ""
163
198
execution_result : Optional [ExecutionResult ] = None
199
+ subscription_id : Optional [str ] = None
200
+
201
+ def _get_value (d : Any , key : str , label : str ) -> Any :
202
+ if not isinstance (d , dict ):
203
+ raise ValueError (f"{ label } is not a dict" )
204
+
205
+ return d .get (key )
206
+
207
+ def _required_value (d : Any , key : str , label : str ) -> Any :
208
+ value = _get_value (d , key , label )
209
+ if value is None :
210
+ raise ValueError (f"null { key } in { label } " )
211
+
212
+ return value
213
+
214
+ def _required_subscription_id (
215
+ d : Any , label : str , must_exist : bool = False , must_not_exist = False
216
+ ) -> str :
217
+ subscription_id = str (_required_value (d , "subscriptionId" , label ))
218
+ if must_exist and (subscription_id not in self .subscriptions ):
219
+ raise ValueError ("unregistered subscriptionId" )
220
+ if must_not_exist and (subscription_id in self .subscriptions ):
221
+ raise ValueError ("previously registered subscriptionId" )
222
+
223
+ return subscription_id
224
+
225
+ def _validate_data_response (d : Any , label : str ) -> dict :
226
+ """Make sure query, mutation or subscription answer conforms.
227
+ The GraphQL spec says only three keys are permitted.
228
+ """
229
+ if not isinstance (d , dict ):
230
+ raise ValueError (f"{ label } is not a dict" )
231
+
232
+ keys = set (d .keys ())
233
+ invalid = keys - {"data" , "errors" , "extensions" }
234
+ if len (invalid ) > 0 :
235
+ raise ValueError (
236
+ f"{ label } contains invalid items: " + ", " .join (invalid )
237
+ )
238
+ return d
164
239
165
240
try :
166
241
json_answer = json .loads (answer )
167
242
168
- event = str (json_answer . get ( "event" ))
243
+ event = str (_required_value ( json_answer , "event" , "answer " ))
169
244
170
245
if event == "subscription:data" :
171
- payload = json_answer . get ( "payload" )
246
+ payload = _required_value ( json_answer , "payload" , "answer " )
172
247
173
- if not isinstance (payload , dict ):
174
- raise ValueError ("payload is not a dict" )
175
-
176
- subscription_id = str (payload .get ("subscriptionId" ))
177
- try :
178
- answer_id = self .subscription_ids_to_query_ids [subscription_id ]
179
- except KeyError :
180
- raise ValueError (
181
- f"subscription '{ subscription_id } ' has not been registerd"
182
- )
183
-
184
- result = payload .get ("result" )
248
+ subscription_id = _required_subscription_id (
249
+ payload , "payload" , must_exist = True
250
+ )
185
251
186
- if not isinstance (result , dict ):
187
- raise ValueError ("result is not a dict" )
252
+ result = _validate_data_response (payload .get ("result" ), "result" )
188
253
189
254
answer_type = "data"
190
255
256
+ subscription = self .subscriptions [subscription_id ]
257
+ answer_id = subscription .listener_id
258
+
191
259
execution_result = ExecutionResult (
192
- errors = payload .get ("errors" ),
193
260
data = result .get ("data" ),
194
- extensions = payload .get ("extensions" ),
261
+ errors = result .get ("errors" ),
262
+ extensions = result .get ("extensions" ),
195
263
)
196
264
197
265
elif event == "phx_reply" :
198
- answer_id = int (json_answer .get ("ref" ))
199
- payload = json_answer .get ("payload" )
200
266
201
- if not isinstance (payload , dict ):
202
- raise ValueError ("payload is not a dict" )
267
+ # Will generate a ValueError if 'ref' is not there
268
+ # or if it is not an integer
269
+ answer_id = int (_required_value (json_answer , "ref" , "answer" ))
203
270
204
- status = str ( payload . get ( "status" ) )
271
+ payload = _required_value ( json_answer , " payload" , "answer" )
205
272
206
- if status == "ok" :
273
+ status = _get_value ( payload , "status" , "payload" )
207
274
275
+ if status == "ok" :
208
276
answer_type = "reply"
209
- response = payload .get ("response" )
210
277
211
- if isinstance (response , dict ) and "subscriptionId" in response :
212
- subscription_id = str (response .get ("subscriptionId" ))
213
- self .subscription_ids_to_query_ids [subscription_id ] = answer_id
278
+ if answer_id in self .listeners :
279
+ response = _required_value (payload , "response" , "payload" )
280
+
281
+ if isinstance (response , dict ) and "subscriptionId" in response :
282
+
283
+ # Subscription answer
284
+ subscription_id = _required_subscription_id (
285
+ response , "response" , must_not_exist = True
286
+ )
287
+
288
+ self .subscriptions [subscription_id ] = Subscription (
289
+ answer_id
290
+ )
291
+
292
+ else :
293
+ # Query or mutation answer
294
+ # GraphQL spec says only three keys are permitted
295
+ response = _validate_data_response (response , "response" )
296
+
297
+ answer_type = "data"
298
+
299
+ execution_result = ExecutionResult (
300
+ data = response .get ("data" ),
301
+ errors = response .get ("errors" ),
302
+ extensions = response .get ("extensions" ),
303
+ )
304
+ else :
305
+ (
306
+ registered_subscription_id ,
307
+ listener_id ,
308
+ ) = self ._find_subscription (answer_id )
309
+ if registered_subscription_id is not None :
310
+ # Unsubscription answer
311
+ response = _required_value (payload , "response" , "payload" )
312
+ subscription_id = _required_subscription_id (
313
+ response , "response"
314
+ )
315
+
316
+ if subscription_id != registered_subscription_id :
317
+ raise ValueError ("subscription id does not match" )
318
+
319
+ answer_type = "complete"
320
+
321
+ answer_id = listener_id
214
322
215
323
elif status == "error" :
216
324
response = payload .get ("response" )
@@ -224,21 +332,28 @@ def _parse_answer(
224
332
raise TransportQueryError (
225
333
str (response .get ("reason" )), query_id = answer_id
226
334
)
227
- raise ValueError ("reply error" )
335
+ raise TransportQueryError ("reply error" , query_id = answer_id )
228
336
229
337
elif status == "timeout" :
230
338
raise TransportQueryError ("reply timeout" , query_id = answer_id )
339
+ else :
340
+ # missing or unrecognized status, just continue
341
+ pass
231
342
232
343
elif event == "phx_error" :
344
+ # Sent if the channel has crashed
345
+ # answer_id will be the "join_ref" for the channel
346
+ # answer_id = int(json_answer.get("ref"))
233
347
raise TransportServerError ("Server error" )
234
348
elif event == "phx_close" :
235
349
answer_type = "close"
236
350
else :
237
- raise ValueError
351
+ raise ValueError ( "unrecognized event" )
238
352
239
353
except ValueError as e :
354
+ log .error (f"Error parsing answer '{ answer } ': { e !r} " )
240
355
raise TransportProtocolError (
241
- "Server did not return a GraphQL result"
356
+ f "Server did not return a GraphQL result: { e !s } "
242
357
) from e
243
358
244
359
return answer_type , answer_id , execution_result
@@ -254,6 +369,38 @@ async def _handle_answer(
254
369
else :
255
370
await super ()._handle_answer (answer_type , answer_id , execution_result )
256
371
372
+ def _remove_listener (self , query_id : int ) -> None :
373
+ """If the listener was a subscription, remove that information."""
374
+ try :
375
+ subscription_id = self ._find_existing_subscription (query_id )
376
+ del self .subscriptions [subscription_id ]
377
+ except Exception :
378
+ pass
379
+ super ()._remove_listener (query_id )
380
+
381
+ def _find_subscription (self , query_id : int ) -> Tuple [Optional [str ], int ]:
382
+ """Perform a reverse lookup to find the subscription id matching
383
+ a listener's query_id.
384
+ """
385
+ for subscription_id , subscription in self .subscriptions .items ():
386
+ if query_id == subscription .listener_id :
387
+ return subscription_id , query_id
388
+ if query_id == subscription .unsubscribe_id :
389
+ return subscription_id , subscription .listener_id
390
+ return None , query_id
391
+
392
+ def _find_existing_subscription (self , query_id : int ) -> str :
393
+ """Perform a reverse lookup to find the subscription id matching
394
+ a listener's query_id.
395
+ """
396
+ subscription_id , _listener_id = self ._find_subscription (query_id )
397
+
398
+ if subscription_id is None :
399
+ raise TransportProtocolError (
400
+ f"No subscription registered for listener { query_id } "
401
+ )
402
+ return subscription_id
403
+
257
404
async def _close_coro (self , e : Exception , clean_close : bool = True ) -> None :
258
405
if self .heartbeat_task is not None :
259
406
self .heartbeat_task .cancel ()
0 commit comments