@@ -135,13 +135,6 @@ async def query_keys(
135
135
136
136
store_queries = []
137
137
for server_name , key_ids in query .items ():
138
- if (
139
- self .federation_domain_whitelist is not None
140
- and server_name not in self .federation_domain_whitelist
141
- ):
142
- logger .debug ("Federation denied with %s" , server_name )
143
- continue
144
-
145
138
if not key_ids :
146
139
key_ids = (None ,)
147
140
for key_id in key_ids :
@@ -153,21 +146,28 @@ async def query_keys(
153
146
154
147
time_now_ms = self .clock .time_msec ()
155
148
156
- # Note that the value is unused.
149
+ # Map server_name->key_id->int. Note that the value of the init is unused.
150
+ # XXX: why don't we just use a set?
157
151
cache_misses : Dict [str , Dict [str , int ]] = {}
158
152
for (server_name , key_id , _ ), key_results in cached .items ():
159
153
results = [(result ["ts_added_ms" ], result ) for result in key_results ]
160
154
161
- if not results and key_id is not None :
162
- cache_misses .setdefault (server_name , {})[key_id ] = 0
155
+ if key_id is None :
156
+ # all keys were requested. Just return what we have without worrying
157
+ # about validity
158
+ for _ , result in results :
159
+ # Cast to bytes since postgresql returns a memoryview.
160
+ json_results .add (bytes (result ["key_json" ]))
163
161
continue
164
162
165
- if key_id is not None :
163
+ miss = False
164
+ if not results :
165
+ miss = True
166
+ else :
166
167
ts_added_ms , most_recent_result = max (results )
167
168
ts_valid_until_ms = most_recent_result ["ts_valid_until_ms" ]
168
169
req_key = query .get (server_name , {}).get (key_id , {})
169
170
req_valid_until = req_key .get ("minimum_valid_until_ts" )
170
- miss = False
171
171
if req_valid_until is not None :
172
172
if ts_valid_until_ms < req_valid_until :
173
173
logger .debug (
@@ -211,19 +211,20 @@ async def query_keys(
211
211
ts_valid_until_ms ,
212
212
time_now_ms ,
213
213
)
214
-
215
- if miss :
216
- cache_misses .setdefault (server_name , {})[key_id ] = 0
217
214
# Cast to bytes since postgresql returns a memoryview.
218
215
json_results .add (bytes (most_recent_result ["key_json" ]))
219
- else :
220
- for _ , result in results :
221
- # Cast to bytes since postgresql returns a memoryview.
222
- json_results .add (bytes (result ["key_json" ]))
216
+
217
+ if miss and query_remote_on_cache_miss :
218
+ # only bother attempting to fetch keys from servers on our whitelist
219
+ if (
220
+ self .federation_domain_whitelist is None
221
+ or server_name in self .federation_domain_whitelist
222
+ ):
223
+ cache_misses .setdefault (server_name , {})[key_id ] = 0
223
224
224
225
# If there is a cache miss, request the missing keys, then recurse (and
225
226
# ensure the result is sent).
226
- if cache_misses and query_remote_on_cache_miss :
227
+ if cache_misses :
227
228
await yieldable_gather_results (
228
229
lambda t : self .fetcher .get_keys (* t ),
229
230
(
0 commit comments