Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 57aeeb3

Browse files
authored
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices, this extends this concept to the Client-Server API. Note that this will likely be spit out into a separate MSC, but is currently part of MSC3983.
1 parent 6efa674 commit 57aeeb3

File tree

12 files changed

+271
-98
lines changed

12 files changed

+271
-98
lines changed

changelog.d/15468.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support claiming more than one OTK at a time.

synapse/appservice/api.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,10 @@ async def push_bulk(
442442
return False
443443

444444
async def claim_client_keys(
445-
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
446-
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
445+
self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
446+
) -> Tuple[
447+
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
448+
]:
447449
"""Claim one time keys from an application service.
448450
449451
Note that any error (including a timeout) is treated as the application
@@ -469,8 +471,10 @@ async def claim_client_keys(
469471

470472
# Create the expected payload shape.
471473
body: Dict[str, Dict[str, List[str]]] = {}
472-
for user_id, device, algorithm in query:
473-
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
474+
for user_id, device, algorithm, count in query:
475+
body.setdefault(user_id, {}).setdefault(device, []).extend(
476+
[algorithm] * count
477+
)
474478

475479
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
476480
try:
@@ -493,11 +497,20 @@ async def claim_client_keys(
493497
# or if some are still missing.
494498
#
495499
# TODO This places a lot of faith in the response shape being correct.
496-
missing = [
497-
(user_id, device, algorithm)
498-
for user_id, device, algorithm in query
499-
if algorithm not in response.get(user_id, {}).get(device, [])
500-
]
500+
missing = []
501+
for user_id, device, algorithm, count in query:
502+
# Count the number of keys in the response for this algorithm by
503+
# checking which key IDs start with the algorithm. This uses that
504+
# True == 1 in Python to generate a count.
505+
response_count = sum(
506+
key_id.startswith(f"{algorithm}:")
507+
for key_id in response.get(user_id, {}).get(device, {})
508+
)
509+
count -= response_count
510+
# If the appservice responds with fewer keys than requested, then
511+
# consider the request unfulfilled.
512+
if count > 0:
513+
missing.append((user_id, device, algorithm, count))
501514

502515
return response, missing
503516

synapse/federation/federation_client.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ async def query_user_devices(
235235
)
236236

237237
async def claim_client_keys(
238-
self, destination: str, content: JsonDict, timeout: Optional[int]
238+
self,
239+
destination: str,
240+
query: Dict[str, Dict[str, Dict[str, int]]],
241+
timeout: Optional[int],
239242
) -> JsonDict:
240243
"""Claims one-time keys for a device hosted on a remote server.
241244
@@ -247,6 +250,50 @@ async def claim_client_keys(
247250
The JSON object from the response
248251
"""
249252
sent_queries_counter.labels("client_one_time_keys").inc()
253+
254+
# Convert the query with counts into a stable and unstable query and check
255+
# if attempting to claim more than 1 OTK.
256+
content: Dict[str, Dict[str, str]] = {}
257+
unstable_content: Dict[str, Dict[str, List[str]]] = {}
258+
use_unstable = False
259+
for user_id, one_time_keys in query.items():
260+
for device_id, algorithms in one_time_keys.items():
261+
if any(count > 1 for count in algorithms.values()):
262+
use_unstable = True
263+
if algorithms:
264+
# For the stable query, choose only the first algorithm.
265+
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
266+
# For the unstable query, repeat each algorithm by count, then
267+
# splat those into chain to get a flattened list of all algorithms.
268+
#
269+
# Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
270+
unstable_content.setdefault(user_id, {})[device_id] = list(
271+
itertools.chain(
272+
*(
273+
itertools.repeat(algorithm, count)
274+
for algorithm, count in algorithms.items()
275+
)
276+
)
277+
)
278+
279+
if use_unstable:
280+
try:
281+
return await self.transport_layer.claim_client_keys_unstable(
282+
destination, unstable_content, timeout
283+
)
284+
except HttpResponseException as e:
285+
# If an error is received that is due to an unrecognised endpoint,
286+
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
287+
# and raise.
288+
if not is_unknown_endpoint(e):
289+
raise
290+
291+
logger.debug(
292+
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
293+
)
294+
else:
295+
logger.debug("Skipping unstable claim client keys API")
296+
250297
return await self.transport_layer.claim_client_keys(
251298
destination, content, timeout
252299
)

synapse/federation/federation_server.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,8 @@ async def on_query_user_devices(
10051005

10061006
@trace
10071007
async def on_claim_client_keys(
1008-
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
1008+
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
10091009
) -> Dict[str, Any]:
1010-
query = []
1011-
for user_id, device_keys in content.get("one_time_keys", {}).items():
1012-
for device_id, algorithm in device_keys.items():
1013-
query.append((user_id, device_id, algorithm))
1014-
10151010
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
10161011
results = await self._e2e_keys_handler.claim_local_one_time_keys(
10171012
query, always_include_fallback_keys=always_include_fallback_keys

synapse/federation/transport/client.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -650,10 +650,10 @@ async def claim_client_keys(
650650
651651
Response:
652652
{
653-
"device_keys": {
653+
"one_time_keys": {
654654
"<user_id>": {
655655
"<device_id>": {
656-
"<algorithm>:<key_id>": "<key_base64>"
656+
"<algorithm>:<key_id>": <OTK JSON>
657657
}
658658
}
659659
}
@@ -669,7 +669,50 @@ async def claim_client_keys(
669669
path = _create_v1_path("/user/keys/claim")
670670

671671
return await self.client.post_json(
672-
destination=destination, path=path, data=query_content, timeout=timeout
672+
destination=destination,
673+
path=path,
674+
data={"one_time_keys": query_content},
675+
timeout=timeout,
676+
)
677+
678+
async def claim_client_keys_unstable(
679+
self, destination: str, query_content: JsonDict, timeout: Optional[int]
680+
) -> JsonDict:
681+
"""Claim one-time keys for a list of devices hosted on a remote server.
682+
683+
Request:
684+
{
685+
"one_time_keys": {
686+
"<user_id>": {
687+
"<device_id>": {"<algorithm>": <count>}
688+
}
689+
}
690+
}
691+
692+
Response:
693+
{
694+
"one_time_keys": {
695+
"<user_id>": {
696+
"<device_id>": {
697+
"<algorithm>:<key_id>": <OTK JSON>
698+
}
699+
}
700+
}
701+
}
702+
703+
Args:
704+
destination: The server to query.
705+
query_content: The user ids to query.
706+
Returns:
707+
A dict containing the one-time keys.
708+
"""
709+
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
710+
711+
return await self.client.post_json(
712+
destination=destination,
713+
path=path,
714+
data={"one_time_keys": query_content},
715+
timeout=timeout,
673716
)
674717

675718
async def get_missing_events(

synapse/federation/transport/server/federation.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
from collections import Counter
1516
from typing import (
1617
TYPE_CHECKING,
1718
Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
577578
async def on_POST(
578579
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
579580
) -> Tuple[int, JsonDict]:
581+
# Generate a count for each algorithm, which is hard-coded to 1.
582+
key_query: List[Tuple[str, str, str, int]] = []
583+
for user_id, device_keys in content.get("one_time_keys", {}).items():
584+
for device_id, algorithm in device_keys.items():
585+
key_query.append((user_id, device_id, algorithm, 1))
586+
580587
response = await self.handler.on_claim_client_keys(
581-
origin, content, always_include_fallback_keys=False
588+
key_query, always_include_fallback_keys=False
582589
)
583590
return 200, response
584591

585592

586593
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
587594
"""
588-
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
589-
always includes fallback keys in the response.
595+
Identical to the stable endpoint (FederationClientKeysClaimServlet) except
596+
it allows for querying for multiple OTKs at once and always includes fallback
597+
keys in the response.
590598
"""
591599

592600
PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
596604
async def on_POST(
597605
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
598606
) -> Tuple[int, JsonDict]:
607+
# Generate a count for each algorithm.
608+
key_query: List[Tuple[str, str, str, int]] = []
609+
for user_id, device_keys in content.get("one_time_keys", {}).items():
610+
for device_id, algorithms in device_keys.items():
611+
counts = Counter(algorithms)
612+
for algorithm, count in counts.items():
613+
key_query.append((user_id, device_id, algorithm, count))
614+
599615
response = await self.handler.on_claim_client_keys(
600-
origin, content, always_include_fallback_keys=True
616+
key_query, always_include_fallback_keys=True
601617
)
602618
return 200, response
603619

@@ -805,6 +821,7 @@ async def on_POST(
805821
FederationClientKeysQueryServlet,
806822
FederationUserDevicesQueryServlet,
807823
FederationClientKeysClaimServlet,
824+
FederationUnstableClientKeysClaimServlet,
808825
FederationThirdPartyInviteExchangeServlet,
809826
On3pidBindServlet,
810827
FederationVersionServlet,

synapse/handlers/appservice.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -841,8 +841,10 @@ async def _check_user_exists(self, user_id: str) -> bool:
841841
return True
842842

843843
async def claim_e2e_one_time_keys(
844-
self, query: Iterable[Tuple[str, str, str]]
845-
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
844+
self, query: Iterable[Tuple[str, str, str, int]]
845+
) -> Tuple[
846+
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
847+
]:
846848
"""Claim one time keys from application services.
847849
848850
Users which are exclusively owned by an application service are sent a
@@ -863,18 +865,18 @@ async def claim_e2e_one_time_keys(
863865
services = self.store.get_app_services()
864866

865867
# Partition the users by appservice.
866-
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
868+
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
867869
missing = []
868-
for user_id, device, algorithm in query:
870+
for user_id, device, algorithm, count in query:
869871
if not self.store.get_if_app_services_interested_in_user(user_id):
870-
missing.append((user_id, device, algorithm))
872+
missing.append((user_id, device, algorithm, count))
871873
continue
872874

873875
# Find the associated appservice.
874876
for service in services:
875877
if service.is_exclusive_user(user_id):
876878
query_by_appservice.setdefault(service.id, []).append(
877-
(user_id, device, algorithm)
879+
(user_id, device, algorithm, count)
878880
)
879881
continue
880882

synapse/handlers/e2e_keys.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ async def on_federation_query_client_keys(
564564

565565
async def claim_local_one_time_keys(
566566
self,
567-
local_query: List[Tuple[str, str, str]],
567+
local_query: List[Tuple[str, str, str, int]],
568568
always_include_fallback_keys: bool,
569569
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
570570
"""Claim one time keys for local users.
@@ -581,6 +581,12 @@ async def claim_local_one_time_keys(
581581
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
582582
"""
583583

584+
# Cap the number of OTKs that can be claimed at once to avoid abuse.
585+
local_query = [
586+
(user_id, device_id, algorithm, min(count, 5))
587+
for user_id, device_id, algorithm, count in local_query
588+
]
589+
584590
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
585591

586592
# If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ async def claim_local_one_time_keys(
607613
# from the appservice for that user ID / device ID. If it is found,
608614
# check if any of the keys match the requested algorithm & are a
609615
# fallback key.
610-
for user_id, device_id, algorithm in local_query:
616+
for user_id, device_id, algorithm, _count in local_query:
611617
# Check if the appservice responded for this query.
612618
as_result = appservice_results.get(user_id, {}).get(device_id, {})
613619
found_otk = False
@@ -630,13 +636,17 @@ async def claim_local_one_time_keys(
630636
.get(device_id, {})
631637
.keys()
632638
)
639+
# Note that it doesn't make sense to request more than 1 fallback key
640+
# per (user_id, device_id, algorithm).
633641
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
634642

635643
else:
636644
# All fallback keys get marked as used.
637645
fallback_query = [
646+
# Note that it doesn't make sense to request more than 1 fallback key
647+
# per (user_id, device_id, algorithm).
638648
(user_id, device_id, algorithm, True)
639-
for user_id, device_id, algorithm in not_found
649+
for user_id, device_id, algorithm, count in not_found
640650
]
641651

642652
# For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ async def claim_local_one_time_keys(
650660
@trace
651661
async def claim_one_time_keys(
652662
self,
653-
query: Dict[str, Dict[str, Dict[str, str]]],
663+
query: Dict[str, Dict[str, Dict[str, int]]],
654664
timeout: Optional[int],
655665
always_include_fallback_keys: bool,
656666
) -> JsonDict:
657-
local_query: List[Tuple[str, str, str]] = []
658-
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
667+
local_query: List[Tuple[str, str, str, int]] = []
668+
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
659669

660-
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
670+
for user_id, one_time_keys in query.items():
661671
# we use UserID.from_string to catch invalid user ids
662672
if self.is_mine(UserID.from_string(user_id)):
663-
for device_id, algorithm in one_time_keys.items():
664-
local_query.append((user_id, device_id, algorithm))
673+
for device_id, algorithms in one_time_keys.items():
674+
for algorithm, count in algorithms.items():
675+
local_query.append((user_id, device_id, algorithm, count))
665676
else:
666677
domain = get_domain_from_id(user_id)
667678
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ async def claim_client_keys(destination: str) -> None:
692703
device_keys = remote_queries[destination]
693704
try:
694705
remote_result = await self.federation.claim_client_keys(
695-
destination, {"one_time_keys": device_keys}, timeout=timeout
706+
destination, device_keys, timeout=timeout
696707
)
697708
for user_id, keys in remote_result["one_time_keys"].items():
698709
if user_id in device_keys:

0 commit comments

Comments
 (0)