54
54
Callable ,
55
55
Collection ,
56
56
Dict ,
57
+ Iterable ,
57
58
List ,
58
59
Optional ,
59
60
Set ,
61
+ Tuple ,
60
62
)
61
63
62
- from synapse .appservice import ApplicationService , ApplicationServiceState
64
+ from synapse .appservice import (
65
+ ApplicationService ,
66
+ ApplicationServiceState ,
67
+ TransactionOneTimeKeyCounts ,
68
+ TransactionUnusedFallbackKeys ,
69
+ )
63
70
from synapse .appservice .api import ApplicationServiceApi
64
71
from synapse .events import EventBase
65
72
from synapse .logging .context import run_in_background
@@ -96,7 +103,7 @@ def __init__(self, hs: "HomeServer"):
96
103
self .as_api = hs .get_application_service_api ()
97
104
98
105
self .txn_ctrl = _TransactionController (self .clock , self .store , self .as_api )
99
- self .queuer = _ServiceQueuer (self .txn_ctrl , self .clock )
106
+ self .queuer = _ServiceQueuer (self .txn_ctrl , self .clock , hs )
100
107
101
108
async def start (self ) -> None :
102
109
logger .info ("Starting appservice scheduler" )
@@ -153,7 +160,9 @@ class _ServiceQueuer:
153
160
appservice at a given time.
154
161
"""
155
162
156
- def __init__ (self , txn_ctrl : "_TransactionController" , clock : Clock ):
163
+ def __init__ (
164
+ self , txn_ctrl : "_TransactionController" , clock : Clock , hs : "HomeServer"
165
+ ):
157
166
# dict of {service_id: [events]}
158
167
self .queued_events : Dict [str , List [EventBase ]] = {}
159
168
# dict of {service_id: [events]}
@@ -165,6 +174,10 @@ def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
165
174
self .requests_in_flight : Set [str ] = set ()
166
175
self .txn_ctrl = txn_ctrl
167
176
self .clock = clock
177
+ self ._msc3202_transaction_extensions_enabled : bool = (
178
+ hs .config .experimental .msc3202_transaction_extensions
179
+ )
180
+ self ._store = hs .get_datastores ().main
168
181
169
182
def start_background_request (self , service : ApplicationService ) -> None :
170
183
# start a sender for this appservice if we don't already have one
@@ -202,15 +215,84 @@ async def _send_request(self, service: ApplicationService) -> None:
202
215
if not events and not ephemeral and not to_device_messages_to_send :
203
216
return
204
217
218
+ one_time_key_counts : Optional [TransactionOneTimeKeyCounts ] = None
219
+ unused_fallback_keys : Optional [TransactionUnusedFallbackKeys ] = None
220
+
221
+ if (
222
+ self ._msc3202_transaction_extensions_enabled
223
+ and service .msc3202_transaction_extensions
224
+ ):
225
+ # Compute the one-time key counts and fallback key usage states
226
+ # for the users which are mentioned in this transaction,
227
+ # as well as the appservice's sender.
228
+ (
229
+ one_time_key_counts ,
230
+ unused_fallback_keys ,
231
+ ) = await self ._compute_msc3202_otk_counts_and_fallback_keys (
232
+ service , events , ephemeral , to_device_messages_to_send
233
+ )
234
+
205
235
try :
206
236
await self .txn_ctrl .send (
207
- service , events , ephemeral , to_device_messages_to_send
237
+ service ,
238
+ events ,
239
+ ephemeral ,
240
+ to_device_messages_to_send ,
241
+ one_time_key_counts ,
242
+ unused_fallback_keys ,
208
243
)
209
244
except Exception :
210
245
logger .exception ("AS request failed" )
211
246
finally :
212
247
self .requests_in_flight .discard (service .id )
213
248
249
+ async def _compute_msc3202_otk_counts_and_fallback_keys (
250
+ self ,
251
+ service : ApplicationService ,
252
+ events : Iterable [EventBase ],
253
+ ephemerals : Iterable [JsonDict ],
254
+ to_device_messages : Iterable [JsonDict ],
255
+ ) -> Tuple [TransactionOneTimeKeyCounts , TransactionUnusedFallbackKeys ]:
256
+ """
257
+ Given a list of the events, ephemeral messages and to-device messages,
258
+ - first computes a list of application services users that may have
259
+ interesting updates to the one-time key counts or fallback key usage.
260
+ - then computes one-time key counts and fallback key usages for those users.
261
+ Given a list of application service users that are interesting,
262
+ compute one-time key counts and fallback key usages for the users.
263
+ """
264
+
265
+ # Set of 'interesting' users who may have updates
266
+ users : Set [str ] = set ()
267
+
268
+ # The sender is always included
269
+ users .add (service .sender )
270
+
271
+ # All AS users that would receive the PDUs or EDUs sent to these rooms
272
+ # are classed as 'interesting'.
273
+ rooms_of_interesting_users : Set [str ] = set ()
274
+ # PDUs
275
+ rooms_of_interesting_users .update (event .room_id for event in events )
276
+ # EDUs
277
+ rooms_of_interesting_users .update (
278
+ ephemeral ["room_id" ] for ephemeral in ephemerals
279
+ )
280
+
281
+ # Look up the AS users in those rooms
282
+ for room_id in rooms_of_interesting_users :
283
+ users .update (
284
+ await self ._store .get_app_service_users_in_room (room_id , service )
285
+ )
286
+
287
+ # Add recipients of to-device messages.
288
+ # device_message["user_id"] is the ID of the recipient.
289
+ users .update (device_message ["user_id" ] for device_message in to_device_messages )
290
+
291
+ # Compute and return the counts / fallback key usage states
292
+ otk_counts = await self ._store .count_bulk_e2e_one_time_keys_for_as (users )
293
+ unused_fbks = await self ._store .get_e2e_bulk_unused_fallback_key_types (users )
294
+ return otk_counts , unused_fbks
295
+
214
296
215
297
class _TransactionController :
216
298
"""Transaction manager.
@@ -238,6 +320,8 @@ async def send(
238
320
events : List [EventBase ],
239
321
ephemeral : Optional [List [JsonDict ]] = None ,
240
322
to_device_messages : Optional [List [JsonDict ]] = None ,
323
+ one_time_key_counts : Optional [TransactionOneTimeKeyCounts ] = None ,
324
+ unused_fallback_keys : Optional [TransactionUnusedFallbackKeys ] = None ,
241
325
) -> None :
242
326
"""
243
327
Create a transaction with the given data and send to the provided
@@ -248,13 +332,19 @@ async def send(
248
332
events: The persistent events to include in the transaction.
249
333
ephemeral: The ephemeral events to include in the transaction.
250
334
to_device_messages: The to-device messages to include in the transaction.
335
+ one_time_key_counts: Counts of remaining one-time keys for relevant
336
+ appservice devices in the transaction.
337
+ unused_fallback_keys: Lists of unused fallback keys for relevant
338
+ appservice devices in the transaction.
251
339
"""
252
340
try :
253
341
txn = await self .store .create_appservice_txn (
254
342
service = service ,
255
343
events = events ,
256
344
ephemeral = ephemeral or [],
257
345
to_device_messages = to_device_messages or [],
346
+ one_time_key_counts = one_time_key_counts or {},
347
+ unused_fallback_keys = unused_fallback_keys or {},
258
348
)
259
349
service_is_up = await self ._is_service_up (service )
260
350
if service_is_up :
0 commit comments