Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 5ed0e8c

Browse files
authored
Cache requests for user's devices from federation (#15675)
This should mitigate the issue where lots of different servers requests the same user's devices all at once.
1 parent d1693f0 commit 5ed0e8c

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

Diff for: changelog.d/15675.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Cache requests for user's devices over federation.

Diff for: synapse/storage/databases/main/devices.py

+4
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,10 @@ def _add_device_change_to_stream_txn(
19411941
user_id,
19421942
stream_ids[-1],
19431943
)
1944+
txn.call_after(
1945+
self._get_e2e_device_keys_for_federation_query_inner.invalidate,
1946+
(user_id,),
1947+
)
19441948

19451949
min_stream_id = stream_ids[0]
19461950

Diff for: synapse/storage/databases/main/end_to_end_keys.py

+65-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import abc
1717
from typing import (
1818
TYPE_CHECKING,
19+
Any,
1920
Collection,
2021
Dict,
2122
Iterable,
@@ -39,6 +40,7 @@
3940
TransactionUnusedFallbackKeys,
4041
)
4142
from synapse.logging.opentracing import log_kv, set_tag, trace
43+
from synapse.replication.tcp.streams._base import DeviceListsStream
4244
from synapse.storage._base import SQLBaseStore, db_to_json
4345
from synapse.storage.database import (
4446
DatabasePool,
@@ -104,6 +106,23 @@ def __init__(
104106
self.hs.config.federation.allow_device_name_lookup_over_federation
105107
)
106108

109+
def process_replication_rows(
110+
self,
111+
stream_name: str,
112+
instance_name: str,
113+
token: int,
114+
rows: Iterable[Any],
115+
) -> None:
116+
if stream_name == DeviceListsStream.NAME:
117+
for row in rows:
118+
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
119+
if row.entity.startswith("@"):
120+
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
121+
(row.entity,)
122+
)
123+
124+
super().process_replication_rows(stream_name, instance_name, token, rows)
125+
107126
async def get_e2e_device_keys_for_federation_query(
108127
self, user_id: str
109128
) -> Tuple[int, List[JsonDict]]:
@@ -114,6 +133,50 @@ async def get_e2e_device_keys_for_federation_query(
114133
"""
115134
now_stream_id = self.get_device_stream_token()
116135

136+
# We need to be careful with the caching here, as we need to always
137+
# return *all* persisted devices, however there may be a lag between a
138+
# new device being persisted and the cache being invalidated.
139+
cached_results = (
140+
self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
141+
user_id, None
142+
)
143+
)
144+
if cached_results is not None:
145+
# Check that there have been no new devices added by another worker
146+
# after the cache. This should be quick as there should be few rows
147+
# with a higher stream ordering.
148+
#
149+
# Note that we invalidate based on the device stream, so we only
150+
# have to check for potential invalidations after the
151+
# `now_stream_id`.
152+
sql = """
153+
SELECT user_id FROM device_lists_stream
154+
WHERE stream_id >= ? AND user_id = ?
155+
"""
156+
rows = await self.db_pool.execute(
157+
"get_e2e_device_keys_for_federation_query_check",
158+
None,
159+
sql,
160+
now_stream_id,
161+
user_id,
162+
)
163+
if not rows:
164+
# No new rows, so cache is still valid.
165+
return now_stream_id, cached_results
166+
167+
# There has, so let's invalidate the cache and run the query.
168+
self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))
169+
170+
results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)
171+
172+
return now_stream_id, results
173+
174+
@cached(iterable=True)
175+
async def _get_e2e_device_keys_for_federation_query_inner(
176+
self, user_id: str
177+
) -> List[JsonDict]:
178+
"""Get all devices (with any device keys) for a user"""
179+
117180
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
118181

119182
if devices:
@@ -134,9 +197,9 @@ async def get_e2e_device_keys_for_federation_query(
134197

135198
results.append(result)
136199

137-
return now_stream_id, results
200+
return results
138201

139-
return now_stream_id, []
202+
return []
140203

141204
@trace
142205
@cancellable

0 commit comments

Comments
 (0)