Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 3cd78bb

Browse files
authored
Add support for MSC2732: olm fallback keys (#8312)
1 parent a024461 commit 3cd78bb

File tree

8 files changed

+215
-1
lines changed

8 files changed

+215
-1
lines changed

Diff for: changelog.d/8312.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).

Diff for: scripts/synapse_port_db

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
9090
"room_stats_state": ["is_federatable"],
9191
"local_media_repository": ["safe_from_quarantine"],
9292
"users": ["shadow_banned"],
93+
"e2e_fallback_keys_json": ["used"],
9394
}
9495

9596

Diff for: synapse/handlers/e2e_keys.py

+16
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,22 @@ async def upload_keys_for_user(self, user_id, device_id, keys):
496496
log_kv(
497497
{"message": "Did not update one_time_keys", "reason": "no keys given"}
498498
)
499+
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
500+
if fallback_keys and isinstance(fallback_keys, dict):
501+
log_kv(
502+
{
503+
"message": "Updating fallback_keys for device.",
504+
"user_id": user_id,
505+
"device_id": device_id,
506+
}
507+
)
508+
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
509+
elif fallback_keys:
510+
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
511+
else:
512+
log_kv(
513+
{"message": "Did not update fallback_keys", "reason": "no keys given"}
514+
)
499515

500516
# the device should have been registered already, but it may have been
501517
# deleted due to a race with a DELETE request. Or we may be using an

Diff for: synapse/handlers/sync.py

+8
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ class SyncResult:
201201
device_lists: List of user_ids whose devices have changed
202202
device_one_time_keys_count: Dict of algorithm to count for one time keys
203203
for this device
204+
device_unused_fallback_key_types: List of key types that have an unused fallback
205+
key
204206
groups: Group updates, if any
205207
"""
206208

@@ -213,6 +215,7 @@ class SyncResult:
213215
to_device = attr.ib(type=List[JsonDict])
214216
device_lists = attr.ib(type=DeviceLists)
215217
device_one_time_keys_count = attr.ib(type=JsonDict)
218+
device_unused_fallback_key_types = attr.ib(type=List[str])
216219
groups = attr.ib(type=Optional[GroupsSyncResult])
217220

218221
def __bool__(self) -> bool:
@@ -1014,10 +1017,14 @@ async def generate_sync_result(
10141017
logger.debug("Fetching OTK data")
10151018
device_id = sync_config.device_id
10161019
one_time_key_counts = {} # type: JsonDict
1020+
unused_fallback_key_types = [] # type: List[str]
10171021
if device_id:
10181022
one_time_key_counts = await self.store.count_e2e_one_time_keys(
10191023
user_id, device_id
10201024
)
1025+
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
1026+
user_id, device_id
1027+
)
10211028

10221029
logger.debug("Fetching group data")
10231030
await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1041,6 +1048,7 @@ async def generate_sync_result(
10411048
device_lists=device_lists,
10421049
groups=sync_result_builder.groups,
10431050
device_one_time_keys_count=one_time_key_counts,
1051+
device_unused_fallback_key_types=unused_fallback_key_types,
10441052
next_batch=sync_result_builder.now_token,
10451053
)
10461054

Diff for: synapse/rest/client/v2_alpha/sync.py

+1
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ async def encode_response(self, time_now, sync_result, access_token_id, filter):
236236
"leave": sync_result.groups.leave,
237237
},
238238
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
239+
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
239240
"next_batch": await sync_result.next_batch.to_string(self.store),
240241
}
241242

Diff for: synapse/storage/databases/main/end_to_end_keys.py

+99-1
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,57 @@ def _count_e2e_one_time_keys(txn):
367367
"count_e2e_one_time_keys", _count_e2e_one_time_keys
368368
)
369369

370+
async def set_e2e_fallback_keys(
371+
self, user_id: str, device_id: str, fallback_keys: JsonDict
372+
) -> None:
373+
"""Set the user's e2e fallback keys.
374+
375+
Args:
376+
user_id: the user whose keys are being set
377+
device_id: the device whose keys are being set
378+
fallback_keys: the keys to set. This is a map from key ID (which is
379+
of the form "algorithm:id") to key data.
380+
"""
381+
# fallback_keys will usually only have one item in it, so using a for
382+
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
383+
# FIXME: make sure that only one key per algorithm is uploaded
384+
for key_id, fallback_key in fallback_keys.items():
385+
algorithm, key_id = key_id.split(":", 1)
386+
await self.db_pool.simple_upsert(
387+
"e2e_fallback_keys_json",
388+
keyvalues={
389+
"user_id": user_id,
390+
"device_id": device_id,
391+
"algorithm": algorithm,
392+
},
393+
values={
394+
"key_id": key_id,
395+
"key_json": json_encoder.encode(fallback_key),
396+
"used": False,
397+
},
398+
desc="set_e2e_fallback_key",
399+
)
400+
401+
@cached(max_entries=10000)
402+
async def get_e2e_unused_fallback_key_types(
403+
self, user_id: str, device_id: str
404+
) -> List[str]:
405+
"""Returns the fallback key types that have an unused key.
406+
407+
Args:
408+
user_id: the user whose keys are being queried
409+
device_id: the device whose keys are being queried
410+
411+
Returns:
412+
a list of key types
413+
"""
414+
return await self.db_pool.simple_select_onecol(
415+
"e2e_fallback_keys_json",
416+
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
417+
retcol="algorithm",
418+
desc="get_e2e_unused_fallback_key_types",
419+
)
420+
370421
async def get_e2e_cross_signing_key(
371422
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
372423
) -> Optional[dict]:
@@ -701,15 +752,37 @@ def _claim_e2e_one_time_keys(txn):
701752
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
702753
" LIMIT 1"
703754
)
755+
fallback_sql = (
756+
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
757+
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
758+
" LIMIT 1"
759+
)
704760
result = {}
705761
delete = []
762+
used_fallbacks = []
706763
for user_id, device_id, algorithm in query_list:
707764
user_result = result.setdefault(user_id, {})
708765
device_result = user_result.setdefault(device_id, {})
709766
txn.execute(sql, (user_id, device_id, algorithm))
710-
for key_id, key_json in txn:
767+
otk_row = txn.fetchone()
768+
if otk_row is not None:
769+
key_id, key_json = otk_row
711770
device_result[algorithm + ":" + key_id] = key_json
712771
delete.append((user_id, device_id, algorithm, key_id))
772+
else:
773+
# no one-time key available, so see if there's a fallback
774+
# key
775+
txn.execute(fallback_sql, (user_id, device_id, algorithm))
776+
fallback_row = txn.fetchone()
777+
if fallback_row is not None:
778+
key_id, key_json, used = fallback_row
779+
device_result[algorithm + ":" + key_id] = key_json
780+
if not used:
781+
used_fallbacks.append(
782+
(user_id, device_id, algorithm, key_id)
783+
)
784+
785+
# drop any one-time keys that were claimed
713786
sql = (
714787
"DELETE FROM e2e_one_time_keys_json"
715788
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +799,23 @@ def _claim_e2e_one_time_keys(txn):
726799
self._invalidate_cache_and_stream(
727800
txn, self.count_e2e_one_time_keys, (user_id, device_id)
728801
)
802+
# mark fallback keys as used
803+
for user_id, device_id, algorithm, key_id in used_fallbacks:
804+
self.db_pool.simple_update_txn(
805+
txn,
806+
"e2e_fallback_keys_json",
807+
{
808+
"user_id": user_id,
809+
"device_id": device_id,
810+
"algorithm": algorithm,
811+
"key_id": key_id,
812+
},
813+
{"used": True},
814+
)
815+
self._invalidate_cache_and_stream(
816+
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
817+
)
818+
729819
return result
730820

731821
return await self.db_pool.runInteraction(
@@ -754,6 +844,14 @@ def delete_e2e_keys_by_device_txn(txn):
754844
self._invalidate_cache_and_stream(
755845
txn, self.count_e2e_one_time_keys, (user_id, device_id)
756846
)
847+
self.db_pool.simple_delete_txn(
848+
txn,
849+
table="e2e_fallback_keys_json",
850+
keyvalues={"user_id": user_id, "device_id": device_id},
851+
)
852+
self._invalidate_cache_and_stream(
853+
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
854+
)
757855

758856
await self.db_pool.runInteraction(
759857
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright 2020 The Matrix.org Foundation C.I.C
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
17+
user_id TEXT NOT NULL, -- The user this fallback key is for.
18+
device_id TEXT NOT NULL, -- The device this fallback key is for.
19+
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
20+
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
21+
key_json TEXT NOT NULL, -- The key as a JSON blob.
22+
used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
23+
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
24+
);

Diff for: tests/handlers/test_e2e_keys.py

+65
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,71 @@ def test_claim_one_time_key(self):
171171
},
172172
)
173173

174+
@defer.inlineCallbacks
175+
def test_fallback_key(self):
176+
local_user = "@boris:" + self.hs.hostname
177+
device_id = "xyz"
178+
fallback_key = {"alg1:k1": "key1"}
179+
otk = {"alg1:k2": "key2"}
180+
181+
yield defer.ensureDeferred(
182+
self.handler.upload_keys_for_user(
183+
local_user,
184+
device_id,
185+
{"org.matrix.msc2732.fallback_keys": fallback_key},
186+
)
187+
)
188+
189+
# claiming an OTK when no OTKs are available should return the fallback
190+
# key
191+
res = yield defer.ensureDeferred(
192+
self.handler.claim_one_time_keys(
193+
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
194+
)
195+
)
196+
self.assertEqual(
197+
res,
198+
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
199+
)
200+
201+
# claiming an OTK again should return the same fallback key
202+
res = yield defer.ensureDeferred(
203+
self.handler.claim_one_time_keys(
204+
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
205+
)
206+
)
207+
self.assertEqual(
208+
res,
209+
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
210+
)
211+
212+
# if the user uploads a one-time key, the next claim should fetch the
213+
# one-time key, and then go back to the fallback
214+
yield defer.ensureDeferred(
215+
self.handler.upload_keys_for_user(
216+
local_user, device_id, {"one_time_keys": otk}
217+
)
218+
)
219+
220+
res = yield defer.ensureDeferred(
221+
self.handler.claim_one_time_keys(
222+
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
223+
)
224+
)
225+
self.assertEqual(
226+
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
227+
)
228+
229+
res = yield defer.ensureDeferred(
230+
self.handler.claim_one_time_keys(
231+
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
232+
)
233+
)
234+
self.assertEqual(
235+
res,
236+
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
237+
)
238+
174239
@defer.inlineCallbacks
175240
def test_replace_master_key(self):
176241
"""uploading a new signing key should make the old signing key unavailable"""

0 commit comments

Comments
 (0)