|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 | import itertools
|
| 17 | +import json |
17 | 18 | import logging
|
18 | 19 | from typing import Any, Dict, Iterable, List, Optional, Tuple
|
19 | 20 |
|
20 | 21 | from signedjson.key import decode_verify_key_bytes
|
| 22 | +from unpaddedbase64 import decode_base64 |
21 | 23 |
|
22 | 24 | from synapse.storage._base import SQLBaseStore
|
23 | 25 | from synapse.storage.database import LoggingTransaction
|
@@ -63,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
|
63 | 65 | """Processes a batch of keys to fetch, and adds the result to `keys`."""
|
64 | 66 |
|
65 | 67 | # batch_iter always returns tuples so it's safe to do len(batch)
|
66 |
| - sql = ( |
67 |
| - "SELECT server_name, key_id, verify_key, ts_valid_until_ms " |
68 |
| - "FROM server_signature_keys WHERE 1=0" |
69 |
| - ) + " OR (server_name=? AND key_id=?)" * len(batch) |
| 68 | + sql = """ |
| 69 | + SELECT server_name, key_id, verify_key, ts_valid_until_ms |
| 70 | + FROM server_signature_keys WHERE 1=0 |
| 71 | + """ + " OR (server_name=? AND key_id=?)" * len( |
| 72 | + batch |
| 73 | + ) |
70 | 74 |
|
71 | 75 | txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
|
72 | 76 |
|
@@ -181,6 +185,74 @@ async def store_server_keys_json(
|
181 | 185 | desc="store_server_keys_json",
|
182 | 186 | )
|
183 | 187 |
|
| 188 | + # invalidate takes a tuple corresponding to the params of |
| 189 | + # _get_server_keys_json. _get_server_keys_json only takes one |
| 190 | + # param, which is itself the 2-tuple (server_name, key_id). |
| 191 | + self._get_server_keys_json.invalidate((((server_name, key_id),))) |
| 192 | + |
| 193 | + @cached() |
| 194 | + def _get_server_keys_json( |
| 195 | + self, server_name_and_key_id: Tuple[str, str] |
| 196 | + ) -> FetchKeyResult: |
| 197 | + raise NotImplementedError() |
| 198 | + |
| 199 | + @cachedList( |
| 200 | + cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids" |
| 201 | + ) |
| 202 | + async def get_server_keys_json( |
| 203 | + self, server_name_and_key_ids: Iterable[Tuple[str, str]] |
| 204 | + ) -> Dict[Tuple[str, str], FetchKeyResult]: |
| 205 | + """ |
| 206 | + Args: |
| 207 | + server_name_and_key_ids: |
| 208 | + iterable of (server_name, key-id) tuples to fetch keys for |
| 209 | +
|
| 210 | + Returns: |
| 211 | + A map from (server_name, key_id) -> FetchKeyResult, or None if the |
| 212 | + key is unknown |
| 213 | + """ |
| 214 | + keys = {} |
| 215 | + |
| 216 | + def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: |
| 217 | + """Processes a batch of keys to fetch, and adds the result to `keys`.""" |
| 218 | + |
| 219 | + # batch_iter always returns tuples so it's safe to do len(batch) |
| 220 | + sql = """ |
| 221 | + SELECT server_name, key_id, key_json, ts_valid_until_ms |
| 222 | + FROM server_keys_json WHERE 1=0 |
| 223 | + """ + " OR (server_name=? AND key_id=?)" * len( |
| 224 | + batch |
| 225 | + ) |
| 226 | + |
| 227 | + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) |
| 228 | + |
| 229 | + for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn: |
| 230 | + if ts_valid_until_ms is None: |
| 231 | + # Old keys may be stored with a ts_valid_until_ms of null, |
| 232 | + # in which case we treat this as if it was set to `0`, i.e. |
| 233 | + # it won't match key requests that define a minimum |
| 234 | + # `ts_valid_until_ms`. |
| 235 | + ts_valid_until_ms = 0 |
| 236 | + |
| 237 | + # The entire signed JSON response is stored in server_keys_json, |
| 238 | + # fetch out the bits needed. |
| 239 | + key_json = json.loads(key_json_bytes) |
| 240 | + key_base64 = key_json["verify_keys"][key_id]["key"] |
| 241 | + |
| 242 | + keys[(server_name, key_id)] = FetchKeyResult( |
| 243 | + verify_key=decode_verify_key_bytes( |
| 244 | + key_id, decode_base64(key_base64) |
| 245 | + ), |
| 246 | + valid_until_ts=ts_valid_until_ms, |
| 247 | + ) |
| 248 | + |
| 249 | + def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: |
| 250 | + for batch in batch_iter(server_name_and_key_ids, 50): |
| 251 | + _get_keys(txn, batch) |
| 252 | + return keys |
| 253 | + |
| 254 | + return await self.db_pool.runInteraction("get_server_keys_json", _txn) |
| 255 | + |
184 | 256 | async def get_server_keys_json_for_remote(
|
185 | 257 | self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
186 | 258 | ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
|
0 commit comments