Skip to content

Commit bc6c40e

Browse files
committed
Improve perf of looking up device lists changes in /sync
1 parent 5b2b312 commit bc6c40e

File tree

2 files changed

+72
-11
lines changed

2 files changed

+72
-11
lines changed

Diff for: synapse/replication/tcp/client.py

+8-6
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,14 @@ async def on_rdata(
112112
token: stream token for this batch of rows
113113
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
114114
"""
115+
all_room_ids: Set[str] = set()
116+
if stream_name == DeviceListsStream.NAME:
117+
prev_token = self.store.get_device_stream_token()
118+
all_room_ids = await self.store.get_all_device_list_changes(
119+
prev_token, token
120+
)
121+
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
122+
115123
self.store.process_replication_rows(stream_name, instance_name, token, rows)
116124
# NOTE: this must be called after process_replication_rows to ensure any
117125
# cache invalidations are first handled before any stream ID advances.
@@ -146,12 +154,6 @@ async def on_rdata(
146154
StreamKeyType.TO_DEVICE, token, users=entities
147155
)
148156
elif stream_name == DeviceListsStream.NAME:
149-
all_room_ids: Set[str] = set()
150-
for row in rows:
151-
if row.entity.startswith("@") and not row.is_signature:
152-
room_ids = await self.store.get_rooms_for_user(row.entity)
153-
all_room_ids.update(room_ids)
154-
155157
# `all_room_ids` can be large, so let's wake up those streams in batches
156158
for batched_room_ids in batch_iter(all_room_ids, 100):
157159
self.notifier.on_new_event(

Diff for: synapse/storage/databases/main/devices.py

+64-5
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,20 @@ def __init__(
129129
prefilled_cache=device_list_prefill,
130130
)
131131

132+
device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
133+
db_conn,
134+
"device_lists_changes_in_room",
135+
entity_column="room_id",
136+
stream_column="stream_id",
137+
max_value=device_list_max,
138+
limit=10000,
139+
)
140+
self._device_list_room_stream_cache = StreamChangeCache(
141+
"DeviceListRoomStreamChangeCache",
142+
min_device_list_room_id,
143+
prefilled_cache=device_list_room_prefill,
144+
)
145+
132146
(
133147
user_signature_stream_prefill,
134148
user_signature_stream_list_id,
@@ -206,6 +220,13 @@ def _invalidate_caches_for_devices(
206220
row.entity, token
207221
)
208222

223+
def device_lists_in_rooms_have_changed(
224+
self, room_ids: StrCollection, token: int
225+
) -> None:
226+
"""Mark every room in `room_ids` as changed at `token` in `_device_list_room_stream_cache`."""
227+
for room_id in room_ids:
228+
self._device_list_room_stream_cache.entity_has_changed(room_id, token)
229+
209230
def get_device_stream_token(self) -> int:
210231
return self._device_list_id_gen.get_current_token()
211232

@@ -1460,6 +1481,12 @@ async def get_device_list_changes_in_rooms(
14601481
if min_stream_id > from_id:
14611482
return None
14621483

1484+
changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
1485+
room_ids, from_id
1486+
)
1487+
if not changed_room_ids:
1488+
return set()
1489+
14631490
sql = """
14641491
SELECT DISTINCT user_id FROM device_lists_changes_in_room
14651492
WHERE {clause} AND stream_id > ? AND stream_id <= ?
@@ -1474,7 +1501,7 @@ def _get_device_list_changes_in_rooms_txn(
14741501
return {user_id for user_id, in txn}
14751502

14761503
changes = set()
1477-
for chunk in batch_iter(room_ids, 1000):
1504+
for chunk in batch_iter(changed_room_ids, 1000):
14781505
clause, args = make_in_list_sql_clause(
14791506
self.database_engine, "room_id", chunk
14801507
)
@@ -1490,6 +1517,34 @@ def _get_device_list_changes_in_rooms_txn(
14901517

14911518
return changes
14921519

1520+
async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
1521+
"""Return the set of room IDs that have rows in `device_lists_changes_in_room`
1522+
with a stream ID greater than `from_id` and less than or equal to `to_id`.
1523+
1524+
Raises a plain `Exception` if `from_id` is below the value returned by `_get_min_device_lists_changes_in_room()`.
1525+
"""
1526+
1527+
min_stream_id = await self._get_min_device_lists_changes_in_room()
1528+
1529+
if min_stream_id > from_id:
1530+
raise Exception("stream ID is too old")
1531+
1532+
sql = """
1533+
SELECT DISTINCT room_id FROM device_lists_changes_in_room
1534+
WHERE stream_id > ? AND stream_id <= ?
1535+
"""
1536+
1537+
def _get_all_device_list_changes_txn(
1538+
txn: LoggingTransaction,
1539+
) -> Set[str]:
1540+
txn.execute(sql, (from_id, to_id))
1541+
return {room_id for room_id, in txn}
1542+
1543+
return await self.db_pool.runInteraction(
1544+
"get_all_device_list_changes",
1545+
_get_all_device_list_changes_txn,
1546+
)
1547+
14931548
async def get_device_list_changes_in_room(
14941549
self, room_id: str, min_stream_id: int
14951550
) -> Collection[Tuple[str, str]]:
@@ -1950,8 +2005,8 @@ def _update_remote_device_list_cache_txn(
19502005
async def add_device_change_to_streams(
19512006
self,
19522007
user_id: str,
1953-
device_ids: Collection[str],
1954-
room_ids: Collection[str],
2008+
device_ids: StrCollection,
2009+
room_ids: StrCollection,
19552010
) -> Optional[int]:
19562011
"""Persist that a user's devices have been updated, and which hosts
19572012
(if any) should be poked.
@@ -2110,8 +2165,8 @@ def _add_device_outbound_room_poke_txn(
21102165
self,
21112166
txn: LoggingTransaction,
21122167
user_id: str,
2113-
device_ids: Iterable[str],
2114-
room_ids: Collection[str],
2168+
device_ids: StrCollection,
2169+
room_ids: StrCollection,
21152170
stream_ids: List[int],
21162171
context: Dict[str, str],
21172172
) -> None:
@@ -2149,6 +2204,10 @@ def _add_device_outbound_room_poke_txn(
21492204
],
21502205
)
21512206

2207+
txn.call_after(
2208+
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
2209+
)
2210+
21522211
async def get_uncoverted_outbound_room_pokes(
21532212
self, start_stream_id: int, start_room_id: str, limit: int = 10
21542213
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:

0 commit comments

Comments
 (0)