@@ -129,6 +129,20 @@ def __init__(
129
129
prefilled_cache = device_list_prefill ,
130
130
)
131
131
132
+ device_list_room_prefill , min_device_list_room_id = self .db_pool .get_cache_dict (
133
+ db_conn ,
134
+ "device_lists_changes_in_room" ,
135
+ entity_column = "room_id" ,
136
+ stream_column = "stream_id" ,
137
+ max_value = device_list_max ,
138
+ limit = 10000 ,
139
+ )
140
+ self ._device_list_room_stream_cache = StreamChangeCache (
141
+ "DeviceListRoomStreamChangeCache" ,
142
+ min_device_list_room_id ,
143
+ prefilled_cache = device_list_room_prefill ,
144
+ )
145
+
132
146
(
133
147
user_signature_stream_prefill ,
134
148
user_signature_stream_list_id ,
@@ -206,6 +220,13 @@ def _invalidate_caches_for_devices(
206
220
row .entity , token
207
221
)
208
222
223
+ def device_lists_in_rooms_have_changed (
224
+ self , room_ids : StrCollection , token : int
225
+ ) -> None :
226
+ "Record that device lists have changed in rooms"
227
+ for room_id in room_ids :
228
+ self ._device_list_room_stream_cache .entity_has_changed (room_id , token )
229
+
209
230
def get_device_stream_token (self ) -> int :
210
231
return self ._device_list_id_gen .get_current_token ()
211
232
@@ -1460,6 +1481,12 @@ async def get_device_list_changes_in_rooms(
1460
1481
if min_stream_id > from_id :
1461
1482
return None
1462
1483
1484
+ changed_room_ids = self ._device_list_room_stream_cache .get_entities_changed (
1485
+ room_ids , from_id
1486
+ )
1487
+ if not changed_room_ids :
1488
+ return set ()
1489
+
1463
1490
sql = """
1464
1491
SELECT DISTINCT user_id FROM device_lists_changes_in_room
1465
1492
WHERE {clause} AND stream_id > ? AND stream_id <= ?
@@ -1474,7 +1501,7 @@ def _get_device_list_changes_in_rooms_txn(
1474
1501
return {user_id for user_id , in txn }
1475
1502
1476
1503
changes = set ()
1477
- for chunk in batch_iter (room_ids , 1000 ):
1504
+ for chunk in batch_iter (changed_room_ids , 1000 ):
1478
1505
clause , args = make_in_list_sql_clause (
1479
1506
self .database_engine , "room_id" , chunk
1480
1507
)
@@ -1490,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(
1490
1517
1491
1518
return changes
1492
1519
1520
+ async def get_all_device_list_changes (self , from_id : int , to_id : int ) -> Set [str ]:
1521
+ """Return the set of rooms where devices have changed since the given
1522
+ stream ID.
1523
+
1524
+ Will raise an exception if the given stream ID is too old.
1525
+ """
1526
+
1527
+ min_stream_id = await self ._get_min_device_lists_changes_in_room ()
1528
+
1529
+ if min_stream_id > from_id :
1530
+ raise Exception ("stream ID is too old" )
1531
+
1532
+ sql = """
1533
+ SELECT DISTINCT room_id FROM device_lists_changes_in_room
1534
+ WHERE stream_id > ? AND stream_id <= ?
1535
+ """
1536
+
1537
+ def _get_all_device_list_changes_txn (
1538
+ txn : LoggingTransaction ,
1539
+ ) -> Set [str ]:
1540
+ txn .execute (sql , (from_id , to_id ))
1541
+ return {room_id for room_id , in txn }
1542
+
1543
+ return await self .db_pool .runInteraction (
1544
+ "get_all_device_list_changes" ,
1545
+ _get_all_device_list_changes_txn ,
1546
+ )
1547
+
1493
1548
async def get_device_list_changes_in_room (
1494
1549
self , room_id : str , min_stream_id : int
1495
1550
) -> Collection [Tuple [str , str ]]:
@@ -1950,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
1950
2005
async def add_device_change_to_streams (
1951
2006
self ,
1952
2007
user_id : str ,
1953
- device_ids : Collection [ str ] ,
1954
- room_ids : Collection [ str ] ,
2008
+ device_ids : StrCollection ,
2009
+ room_ids : StrCollection ,
1955
2010
) -> Optional [int ]:
1956
2011
"""Persist that a user's devices have been updated, and which hosts
1957
2012
(if any) should be poked.
@@ -2110,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
2110
2165
self ,
2111
2166
txn : LoggingTransaction ,
2112
2167
user_id : str ,
2113
- device_ids : Iterable [ str ] ,
2114
- room_ids : Collection [ str ] ,
2168
+ device_ids : StrCollection ,
2169
+ room_ids : StrCollection ,
2115
2170
stream_ids : List [int ],
2116
2171
context : Dict [str , str ],
2117
2172
) -> None :
@@ -2149,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
2149
2204
],
2150
2205
)
2151
2206
2207
+ txn .call_after (
2208
+ self .device_lists_in_rooms_have_changed , room_ids , max (stream_ids )
2209
+ )
2210
+
2152
2211
async def get_uncoverted_outbound_room_pokes (
2153
2212
self , start_stream_id : int , start_room_id : str , limit : int = 10
2154
2213
) -> List [Tuple [str , str , str , int , Optional [Dict [str , str ]]]]:
0 commit comments