16
16
import abc
17
17
from typing import (
18
18
TYPE_CHECKING ,
19
+ Any ,
19
20
Collection ,
20
21
Dict ,
21
22
Iterable ,
39
40
TransactionUnusedFallbackKeys ,
40
41
)
41
42
from synapse .logging .opentracing import log_kv , set_tag , trace
43
+ from synapse .replication .tcp .streams ._base import DeviceListsStream
42
44
from synapse .storage ._base import SQLBaseStore , db_to_json
43
45
from synapse .storage .database import (
44
46
DatabasePool ,
@@ -104,6 +106,23 @@ def __init__(
104
106
self .hs .config .federation .allow_device_name_lookup_over_federation
105
107
)
106
108
109
+ def process_replication_rows (
110
+ self ,
111
+ stream_name : str ,
112
+ instance_name : str ,
113
+ token : int ,
114
+ rows : Iterable [Any ],
115
+ ) -> None :
116
+ if stream_name == DeviceListsStream .NAME :
117
+ for row in rows :
118
+ assert isinstance (row , DeviceListsStream .DeviceListsStreamRow )
119
+ if row .entity .startswith ("@" ):
120
+ self ._get_e2e_device_keys_for_federation_query_inner .invalidate (
121
+ (row .entity ,)
122
+ )
123
+
124
+ super ().process_replication_rows (stream_name , instance_name , token , rows )
125
+
107
126
async def get_e2e_device_keys_for_federation_query (
108
127
self , user_id : str
109
128
) -> Tuple [int , List [JsonDict ]]:
@@ -114,6 +133,50 @@ async def get_e2e_device_keys_for_federation_query(
114
133
"""
115
134
now_stream_id = self .get_device_stream_token ()
116
135
136
+ # We need to be careful with the caching here, as we need to always
137
+ # return *all* persisted devices, however there may be a lag between a
138
+ # new device being persisted and the cache being invalidated.
139
+ cached_results = (
140
+ self ._get_e2e_device_keys_for_federation_query_inner .cache .get_immediate (
141
+ user_id , None
142
+ )
143
+ )
144
+ if cached_results is not None :
145
+ # Check that there have been no new devices added by another worker
146
+ # after the cache. This should be quick as there should be few rows
147
+ # with a higher stream ordering.
148
+ #
149
+ # Note that we invalidate based on the device stream, so we only
150
+ # have to check for potential invalidations after the
151
+ # `now_stream_id`.
152
+ sql = """
153
+ SELECT user_id FROM device_lists_stream
154
+ WHERE stream_id >= ? AND user_id = ?
155
+ """
156
+ rows = await self .db_pool .execute (
157
+ "get_e2e_device_keys_for_federation_query_check" ,
158
+ None ,
159
+ sql ,
160
+ now_stream_id ,
161
+ user_id ,
162
+ )
163
+ if not rows :
164
+ # No new rows, so cache is still valid.
165
+ return now_stream_id , cached_results
166
+
167
+ # There has, so let's invalidate the cache and run the query.
168
+ self ._get_e2e_device_keys_for_federation_query_inner .invalidate ((user_id ,))
169
+
170
+ results = await self ._get_e2e_device_keys_for_federation_query_inner (user_id )
171
+
172
+ return now_stream_id , results
173
+
174
+ @cached (iterable = True )
175
+ async def _get_e2e_device_keys_for_federation_query_inner (
176
+ self , user_id : str
177
+ ) -> List [JsonDict ]:
178
+ """Get all devices (with any device keys) for a user"""
179
+
117
180
devices = await self .get_e2e_device_keys_and_signatures ([(user_id , None )])
118
181
119
182
if devices :
@@ -134,9 +197,9 @@ async def get_e2e_device_keys_for_federation_query(
134
197
135
198
results .append (result )
136
199
137
- return now_stream_id , results
200
+ return results
138
201
139
- return now_stream_id , []
202
+ return []
140
203
141
204
@trace
142
205
@cancellable
0 commit comments