@@ -564,7 +564,7 @@ async def on_federation_query_client_keys(
564
564
565
565
async def claim_local_one_time_keys (
566
566
self ,
567
- local_query : List [Tuple [str , str , str ]],
567
+ local_query : List [Tuple [str , str , str , int ]],
568
568
always_include_fallback_keys : bool ,
569
569
) -> Iterable [Dict [str , Dict [str , Dict [str , JsonDict ]]]]:
570
570
"""Claim one time keys for local users.
@@ -581,6 +581,12 @@ async def claim_local_one_time_keys(
581
581
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
582
582
"""
583
583
584
+ # Cap the number of OTKs that can be claimed at once to avoid abuse.
585
+ local_query = [
586
+ (user_id , device_id , algorithm , min (count , 5 ))
587
+ for user_id , device_id , algorithm , count in local_query
588
+ ]
589
+
584
590
otk_results , not_found = await self .store .claim_e2e_one_time_keys (local_query )
585
591
586
592
# If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ async def claim_local_one_time_keys(
607
613
# from the appservice for that user ID / device ID. If it is found,
608
614
# check if any of the keys match the requested algorithm & are a
609
615
# fallback key.
610
- for user_id , device_id , algorithm in local_query :
616
+ for user_id , device_id , algorithm , _count in local_query :
611
617
# Check if the appservice responded for this query.
612
618
as_result = appservice_results .get (user_id , {}).get (device_id , {})
613
619
found_otk = False
@@ -630,13 +636,17 @@ async def claim_local_one_time_keys(
630
636
.get (device_id , {})
631
637
.keys ()
632
638
)
639
+ # Note that it doesn't make sense to request more than 1 fallback key
640
+ # per (user_id, device_id, algorithm).
633
641
fallback_query .append ((user_id , device_id , algorithm , mark_as_used ))
634
642
635
643
else :
636
644
# All fallback keys get marked as used.
637
645
fallback_query = [
646
+ # Note that it doesn't make sense to request more than 1 fallback key
647
+ # per (user_id, device_id, algorithm).
638
648
(user_id , device_id , algorithm , True )
639
- for user_id , device_id , algorithm in not_found
649
+ for user_id , device_id , algorithm , count in not_found
640
650
]
641
651
642
652
# For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ async def claim_local_one_time_keys(
650
660
@trace
651
661
async def claim_one_time_keys (
652
662
self ,
653
- query : Dict [str , Dict [str , Dict [str , str ]]],
663
+ query : Dict [str , Dict [str , Dict [str , int ]]],
654
664
timeout : Optional [int ],
655
665
always_include_fallback_keys : bool ,
656
666
) -> JsonDict :
657
- local_query : List [Tuple [str , str , str ]] = []
658
- remote_queries : Dict [str , Dict [str , Dict [str , str ]]] = {}
667
+ local_query : List [Tuple [str , str , str , int ]] = []
668
+ remote_queries : Dict [str , Dict [str , Dict [str , Dict [ str , int ] ]]] = {}
659
669
660
- for user_id , one_time_keys in query .get ( "one_time_keys" , {}). items ():
670
+ for user_id , one_time_keys in query .items ():
661
671
# we use UserID.from_string to catch invalid user ids
662
672
if self .is_mine (UserID .from_string (user_id )):
663
- for device_id , algorithm in one_time_keys .items ():
664
- local_query .append ((user_id , device_id , algorithm ))
673
+ for device_id , algorithms in one_time_keys .items ():
674
+ for algorithm , count in algorithms .items ():
675
+ local_query .append ((user_id , device_id , algorithm , count ))
665
676
else :
666
677
domain = get_domain_from_id (user_id )
667
678
remote_queries .setdefault (domain , {})[user_id ] = one_time_keys
@@ -692,7 +703,7 @@ async def claim_client_keys(destination: str) -> None:
692
703
device_keys = remote_queries [destination ]
693
704
try :
694
705
remote_result = await self .federation .claim_client_keys (
695
- destination , { "one_time_keys" : device_keys } , timeout = timeout
706
+ destination , device_keys , timeout = timeout
696
707
)
697
708
for user_id , keys in remote_result ["one_time_keys" ].items ():
698
709
if user_id in device_keys :
0 commit comments