Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 5282ba1

Browse files
authored
Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314)
Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse.
1 parent 57481ca commit 5282ba1

File tree

9 files changed

+355
-29
lines changed

9 files changed

+355
-29
lines changed

changelog.d/15314.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).

synapse/appservice/api.py

+56
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,62 @@ async def push_bulk(
388388
failed_transactions_counter.labels(service.id).inc()
389389
return False
390390

391+
async def claim_client_keys(
392+
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
393+
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
394+
"""Claim one time keys from an application service.
395+
396+
Args:
397+
query: An iterable of tuples of (user ID, device ID, algorithm).
398+
399+
Returns:
400+
A tuple of:
401+
A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
402+
403+
A copy of the input which has not been fulfilled because the
404+
appservice doesn't support this endpoint or has not returned
405+
data for that tuple.
406+
"""
407+
if service.url is None:
408+
return {}, query
409+
410+
# This is required by the configuration.
411+
assert service.hs_token is not None
412+
413+
# Create the expected payload shape.
414+
body: Dict[str, Dict[str, List[str]]] = {}
415+
for user_id, device, algorithm in query:
416+
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
417+
418+
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
419+
try:
420+
response = await self.post_json_get_json(
421+
uri,
422+
body,
423+
headers={"Authorization": [f"Bearer {service.hs_token}"]},
424+
)
425+
except CodeMessageException as e:
426+
# The appservice doesn't support this endpoint.
427+
if e.code == 404 or e.code == 405:
428+
return {}, query
429+
logger.warning("claim_keys to %s received %s", uri, e.code)
430+
return {}, query
431+
except Exception as ex:
432+
logger.warning("claim_keys to %s threw exception %s", uri, ex)
433+
return {}, query
434+
435+
# Check if the appservice fulfilled all of the queried user/device/algorithms
436+
# or if some are still missing.
437+
#
438+
# TODO This places a lot of faith in the response shape being correct.
439+
missing = [
440+
(user_id, device, algorithm)
441+
for user_id, device, algorithm in query
442+
if algorithm not in response.get(user_id, {}).get(device, [])
443+
]
444+
445+
return response, missing
446+
391447
def _serialize(
392448
self, service: "ApplicationService", events: Iterable[EventBase]
393449
) -> List[JsonDict]:

synapse/config/experimental.py

+5
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
7474
"msc3202_transaction_extensions", False
7575
)
7676

77+
# MSC3983: Proxying OTK claim requests to exclusive ASes.
78+
self.msc3983_appservice_otk_claims: bool = experimental.get(
79+
"msc3983_appservice_otk_claims", False
80+
)
81+
7782
# MSC3706 (server-side support for partial state in /send_join responses)
7883
# Synapse will always serve partial state responses to requests using the stable
7984
# query parameter `omit_members`. If this flag is set, Synapse will also serve

synapse/federation/federation_server.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
8787
from synapse.storage.roommember import MemberSummary
8888
from synapse.types import JsonDict, StateMap, get_domain_from_id
89-
from synapse.util import json_decoder, unwrapFirstError
89+
from synapse.util import unwrapFirstError
9090
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
9191
from synapse.util.caches.response_cache import ResponseCache
9292
from synapse.util.stringutils import parse_server_name
@@ -135,6 +135,7 @@ def __init__(self, hs: "HomeServer"):
135135
self.state = hs.get_state_handler()
136136
self._event_auth_handler = hs.get_event_auth_handler()
137137
self._room_member_handler = hs.get_room_member_handler()
138+
self._e2e_keys_handler = hs.get_e2e_keys_handler()
138139

139140
self._state_storage_controller = hs.get_storage_controllers().state
140141

@@ -1012,15 +1013,14 @@ async def on_claim_client_keys(
10121013
query.append((user_id, device_id, algorithm))
10131014

10141015
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
1015-
results = await self.store.claim_e2e_one_time_keys(query)
1016-
1017-
json_result: Dict[str, Dict[str, dict]] = {}
1018-
for user_id, device_keys in results.items():
1019-
for device_id, keys in device_keys.items():
1020-
for key_id, json_str in keys.items():
1021-
json_result.setdefault(user_id, {})[device_id] = {
1022-
key_id: json_decoder.decode(json_str)
1023-
}
1016+
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
1017+
1018+
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
1019+
for result in results:
1020+
for user_id, device_keys in result.items():
1021+
for device_id, keys in device_keys.items():
1022+
for key_id, key in keys.items():
1023+
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
10241024

10251025
logger.info(
10261026
"Claimed one-time-keys: %s",

synapse/handlers/appservice.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
15+
from typing import (
16+
TYPE_CHECKING,
17+
Collection,
18+
Dict,
19+
Iterable,
20+
List,
21+
Optional,
22+
Tuple,
23+
Union,
24+
)
1625

1726
from prometheus_client import Counter
1827

@@ -829,3 +838,66 @@ async def _check_user_exists(self, user_id: str) -> bool:
829838
if unknown_user:
830839
return await self.query_user_exists(user_id)
831840
return True
841+
842+
async def claim_e2e_one_time_keys(
843+
self, query: Iterable[Tuple[str, str, str]]
844+
) -> Tuple[
845+
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
846+
]:
847+
"""Claim one time keys from application services.
848+
849+
Args:
850+
query: An iterable of tuples of (user ID, device ID, algorithm).
851+
852+
Returns:
853+
A tuple of:
854+
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
855+
856+
A copy of the input which has not been fulfilled (either because
857+
they are not appservice users or the appservice does not support
858+
providing OTKs).
859+
"""
860+
services = self.store.get_app_services()
861+
862+
# Partition the users by appservice.
863+
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
864+
missing = []
865+
for user_id, device, algorithm in query:
866+
if not self.store.get_if_app_services_interested_in_user(user_id):
867+
missing.append((user_id, device, algorithm))
868+
continue
869+
870+
# Find the associated appservice.
871+
for service in services:
872+
if service.is_exclusive_user(user_id):
873+
query_by_appservice.setdefault(service.id, []).append(
874+
(user_id, device, algorithm)
875+
)
876+
continue
877+
878+
# Query each service in parallel.
879+
results = await make_deferred_yieldable(
880+
defer.DeferredList(
881+
[
882+
run_in_background(
883+
self.appservice_api.claim_client_keys,
884+
# We know this must be an app service.
885+
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
886+
service_query,
887+
)
888+
for service_id, service_query in query_by_appservice.items()
889+
],
890+
consumeErrors=True,
891+
)
892+
)
893+
894+
# Patch together the results -- they are all independent (since they
895+
# require exclusive control over the users). They get returned as a list
896+
# and the caller combines them.
897+
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
898+
for success, result in results:
899+
if success:
900+
claimed_keys.append(result[0])
901+
missing.extend(result[1])
902+
903+
return claimed_keys, missing

synapse/handlers/e2e_keys.py

+49-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
1716
import logging
1817
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
1918

@@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"):
5352
self.store = hs.get_datastores().main
5453
self.federation = hs.get_federation_client()
5554
self.device_handler = hs.get_device_handler()
55+
self._appservice_handler = hs.get_application_service_handler()
5656
self.is_mine = hs.is_mine
5757
self.clock = hs.get_clock()
5858

@@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"):
8888
max_count=10,
8989
)
9090

91+
self._query_appservices_for_otks = (
92+
hs.config.experimental.msc3983_appservice_otk_claims
93+
)
94+
9195
@trace
9296
@cancellable
9397
async def query_devices(
@@ -542,6 +546,42 @@ async def on_federation_query_client_keys(
542546

543547
return ret
544548

549+
async def claim_local_one_time_keys(
550+
self, local_query: List[Tuple[str, str, str]]
551+
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
552+
"""Claim one time keys for local users.
553+
554+
1. Attempt to claim OTKs from the database.
555+
2. Ask application services if they provide OTKs.
556+
3. Attempt to fetch fallback keys from the database.
557+
558+
Args:
559+
local_query: An iterable of tuples of (user ID, device ID, algorithm).
560+
561+
Returns:
562+
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
563+
"""
564+
565+
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
566+
567+
# If the application services have not provided any keys via the C-S
568+
# API, query it directly for one-time keys.
569+
if self._query_appservices_for_otks:
570+
(
571+
appservice_results,
572+
not_found,
573+
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
574+
else:
575+
appservice_results = []
576+
577+
# For each user that does not have a one-time keys available, see if
578+
# there is a fallback key.
579+
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
580+
581+
# Return the results in order, each item from the input query should
582+
# only appear once in the combined list.
583+
return (otk_results, *appservice_results, fallback_results)
584+
545585
@trace
546586
async def claim_one_time_keys(
547587
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
@@ -561,17 +601,18 @@ async def claim_one_time_keys(
561601
set_tag("local_key_query", str(local_query))
562602
set_tag("remote_key_query", str(remote_queries))
563603

564-
results = await self.store.claim_e2e_one_time_keys(local_query)
604+
results = await self.claim_local_one_time_keys(local_query)
565605

566606
# A map of user ID -> device ID -> key ID -> key.
567607
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
608+
for result in results:
609+
for user_id, device_keys in result.items():
610+
for device_id, keys in device_keys.items():
611+
for key_id, key in keys.items():
612+
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
613+
614+
# Remote failures.
568615
failures: Dict[str, JsonDict] = {}
569-
for user_id, device_keys in results.items():
570-
for device_id, keys in device_keys.items():
571-
for key_id, json_str in keys.items():
572-
json_result.setdefault(user_id, {})[device_id] = {
573-
key_id: json_decoder.decode(json_str)
574-
}
575616

576617
@trace
577618
async def claim_client_keys(destination: str) -> None:

synapse/storage/databases/main/end_to_end_keys.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from synapse.storage.engines import PostgresEngine
5252
from synapse.storage.util.id_generators import StreamIdGenerator
5353
from synapse.types import JsonDict
54-
from synapse.util import json_encoder
54+
from synapse.util import json_decoder, json_encoder
5555
from synapse.util.caches.descriptors import cached, cachedList
5656
from synapse.util.cancellation import cancellable
5757
from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ def get_device_stream_token(self) -> int:
10281028

10291029
async def claim_e2e_one_time_keys(
10301030
self, query_list: Iterable[Tuple[str, str, str]]
1031-
) -> Dict[str, Dict[str, Dict[str, str]]]:
1031+
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
10321032
"""Take a list of one time keys out of the database.
10331033
10341034
Args:
10351035
query_list: An iterable of tuples of (user ID, device ID, algorithm).
10361036
10371037
Returns:
1038-
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
1038+
A tuple pf:
1039+
A map of user ID -> a map device ID -> a map of key ID -> JSON.
1040+
1041+
A copy of the input which has not been fulfilled.
10391042
"""
10401043

10411044
@trace
@@ -1115,7 +1118,8 @@ def _claim_e2e_one_time_key_returning(
11151118
key_id, key_json = otk_row
11161119
return f"{algorithm}:{key_id}", key_json
11171120

1118-
results: Dict[str, Dict[str, Dict[str, str]]] = {}
1121+
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
1122+
missing: List[Tuple[str, str, str]] = []
11191123
for user_id, device_id, algorithm in query_list:
11201124
if self.database_engine.supports_returning:
11211125
# If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ def _claim_e2e_one_time_key_returning(
11381142
device_results = results.setdefault(user_id, {}).setdefault(
11391143
device_id, {}
11401144
)
1141-
device_results[claim_row[0]] = claim_row[1]
1142-
continue
1145+
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
1146+
else:
1147+
missing.append((user_id, device_id, algorithm))
1148+
1149+
return results, missing
1150+
1151+
async def claim_e2e_fallback_keys(
1152+
self, query_list: Iterable[Tuple[str, str, str]]
1153+
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
1154+
"""Take a list of fallback keys out of the database.
11431155
1144-
# No one-time key available, so see if there's a fallback
1145-
# key
1156+
Args:
1157+
query_list: An iterable of tuples of (user ID, device ID, algorithm).
1158+
1159+
Returns:
1160+
A map of user ID -> a map device ID -> a map of key ID -> JSON.
1161+
"""
1162+
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
1163+
for user_id, device_id, algorithm in query_list:
11461164
row = await self.db_pool.simple_select_one(
11471165
table="e2e_fallback_keys_json",
11481166
keyvalues={
@@ -1179,7 +1197,7 @@ def _claim_e2e_one_time_key_returning(
11791197
)
11801198

11811199
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
1182-
device_results[f"{algorithm}:{key_id}"] = key_json
1200+
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
11831201

11841202
return results
11851203

0 commit comments

Comments
 (0)