Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 400ba67

Browse files
committed
Modify StoreKeyFetcher to read from server_keys_json.
Previously this read from server_signature_keys which is only written by PerspectivesKeyFetcher. Any cached keys which were directly fetched (i.e. via ServerKeyFetcher) would not be used and would be refetched.
1 parent 9d2c890 commit 400ba67

File tree

4 files changed

+101
-32
lines changed

4 files changed

+101
-32
lines changed

changelog.d/15417.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a long-standing bug where cached key results which were directly fetched would not be properly re-used.

synapse/crypto/keyring.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ async def _fetch_keys(
510510
for key_id in queue_value.key_ids
511511
)
512512

513-
res = await self.store.get_server_signature_keys(key_ids_to_fetch)
513+
res = await self.store.get_server_keys_json(key_ids_to_fetch)
514514
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
515515
for (server_name, key_id), key in res.items():
516516
keys.setdefault(server_name, {})[key_id] = key

synapse/storage/databases/main/keys.py

+76-4
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# limitations under the License.
1515

1616
import itertools
17+
import json
1718
import logging
1819
from typing import Any, Dict, Iterable, List, Optional, Tuple
1920

2021
from signedjson.key import decode_verify_key_bytes
22+
from unpaddedbase64 import decode_base64
2123

2224
from synapse.storage._base import SQLBaseStore
2325
from synapse.storage.database import LoggingTransaction
@@ -63,10 +65,12 @@ def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
6365
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
6466

6567
# batch_iter always returns tuples so it's safe to do len(batch)
66-
sql = (
67-
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
68-
"FROM server_signature_keys WHERE 1=0"
69-
) + " OR (server_name=? AND key_id=?)" * len(batch)
68+
sql = """
69+
SELECT server_name, key_id, verify_key, ts_valid_until_ms
70+
FROM server_signature_keys WHERE 1=0
71+
""" + " OR (server_name=? AND key_id=?)" * len(
72+
batch
73+
)
7074

7175
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
7276

@@ -181,6 +185,74 @@ async def store_server_keys_json(
181185
desc="store_server_keys_json",
182186
)
183187

188+
# invalidate takes a tuple corresponding to the params of
189+
# _get_server_keys_json. _get_server_keys_json only takes one
190+
# param, which is itself the 2-tuple (server_name, key_id).
191+
self._get_server_keys_json.invalidate((((server_name, key_id),)))
192+
193+
@cached()
194+
def _get_server_keys_json(
195+
self, server_name_and_key_id: Tuple[str, str]
196+
) -> FetchKeyResult:
197+
raise NotImplementedError()
198+
199+
@cachedList(
200+
cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
201+
)
202+
async def get_server_keys_json(
203+
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
204+
) -> Dict[Tuple[str, str], FetchKeyResult]:
205+
"""
206+
Args:
207+
server_name_and_key_ids:
208+
iterable of (server_name, key-id) tuples to fetch keys for
209+
210+
Returns:
211+
A map from (server_name, key_id) -> FetchKeyResult, or None if the
212+
key is unknown
213+
"""
214+
keys = {}
215+
216+
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
217+
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
218+
219+
# batch_iter always returns tuples so it's safe to do len(batch)
220+
sql = """
221+
SELECT server_name, key_id, key_json, ts_valid_until_ms
222+
FROM server_keys_json WHERE 1=0
223+
""" + " OR (server_name=? AND key_id=?)" * len(
224+
batch
225+
)
226+
227+
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
228+
229+
for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
230+
if ts_valid_until_ms is None:
231+
# Old keys may be stored with a ts_valid_until_ms of null,
232+
# in which case we treat this as if it was set to `0`, i.e.
233+
# it won't match key requests that define a minimum
234+
# `ts_valid_until_ms`.
235+
ts_valid_until_ms = 0
236+
237+
# The entire signed JSON response is stored in server_keys_json,
238+
# fetch out the bits needed.
239+
key_json = json.loads(key_json_bytes)
240+
key_base64 = key_json["verify_keys"][key_id]["key"]
241+
242+
keys[(server_name, key_id)] = FetchKeyResult(
243+
verify_key=decode_verify_key_bytes(
244+
key_id, decode_base64(key_base64)
245+
),
246+
valid_until_ts=ts_valid_until_ms,
247+
)
248+
249+
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
250+
for batch in batch_iter(server_name_and_key_ids, 50):
251+
_get_keys(txn, batch)
252+
return keys
253+
254+
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
255+
184256
async def get_server_keys_json_for_remote(
185257
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
186258
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:

tests/crypto/test_keyring.py

+23-27
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,23 @@ def test_verify_json_for_server(self) -> None:
190190
kr = keyring.Keyring(self.hs)
191191

192192
key1 = signedjson.key.generate_signing_key("1")
193-
r = self.hs.get_datastores().main.store_server_signature_keys(
193+
r = self.hs.get_datastores().main.store_server_keys_json(
194194
"server9",
195-
int(time.time() * 1000),
196-
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
195+
get_key_id(key1),
196+
from_server="test",
197+
ts_now_ms=int(time.time() * 1000),
198+
ts_expires_ms=1000,
199+
# The entire response gets signed & stored, just include the bits we
200+
# care about.
201+
key_json_bytes=canonicaljson.encode_canonical_json(
202+
{
203+
"verify_keys": {
204+
get_key_id(key1): {
205+
"key": encode_verify_key_base64(get_verify_key(key1))
206+
}
207+
}
208+
}
209+
),
197210
)
198211
self.get_success(r)
199212

@@ -280,10 +293,6 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
280293
mock_fetcher = Mock()
281294
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
282295

283-
kr = keyring.Keyring(
284-
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
285-
)
286-
287296
key1 = signedjson.key.generate_signing_key("1")
288297
r = self.hs.get_datastores().main.store_server_signature_keys(
289298
"server9",
@@ -298,27 +307,12 @@ def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
298307
json1: JsonDict = {}
299308
signedjson.sign.sign_json(json1, "server9", key1)
300309

301-
# should fail immediately on an unsigned object
302-
d = kr.verify_json_for_server("server9", {}, 0)
303-
self.get_failure(d, SynapseError)
304-
305-
# should fail on a signed object with a non-zero minimum_valid_until_ms,
306-
# as it tries to refetch the keys and fails.
307-
d = kr.verify_json_for_server("server9", json1, 500)
308-
self.get_failure(d, SynapseError)
309-
310-
# We expect the keyring tried to refetch the key once.
311-
mock_fetcher.get_keys.assert_called_once_with(
312-
"server9", [get_key_id(key1)], 500
313-
)
314-
315310
# should succeed on a signed object with a 0 minimum_valid_until_ms
316-
d = kr.verify_json_for_server(
317-
"server9",
318-
json1,
319-
0,
311+
d = self.hs.get_datastores().main.get_server_signature_keys(
312+
[("server9", get_key_id(key1))]
320313
)
321-
self.get_success(d)
314+
result = self.get_success(d)
315+
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0)
322316

323317
def test_verify_json_dedupes_key_requests(self) -> None:
324318
"""Two requests for the same key should be deduped."""
@@ -464,7 +458,9 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict:
464458
# check that the perspectives store is correctly updated
465459
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
466460
key_json = self.get_success(
467-
self.hs.get_datastores().main.get_server_keys_json_for_remote([lookup_triplet])
461+
self.hs.get_datastores().main.get_server_keys_json_for_remote(
462+
[lookup_triplet]
463+
)
468464
)
469465
res_keys = key_json[lookup_triplet]
470466
self.assertEqual(len(res_keys), 1)

0 commit comments

Comments
 (0)