13
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
-
17
16
import logging
18
17
from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Mapping , Optional , Tuple
19
18
@@ -53,6 +52,7 @@ def __init__(self, hs: "HomeServer"):
53
52
self .store = hs .get_datastores ().main
54
53
self .federation = hs .get_federation_client ()
55
54
self .device_handler = hs .get_device_handler ()
55
+ self ._appservice_handler = hs .get_application_service_handler ()
56
56
self .is_mine = hs .is_mine
57
57
self .clock = hs .get_clock ()
58
58
@@ -88,6 +88,10 @@ def __init__(self, hs: "HomeServer"):
88
88
max_count = 10 ,
89
89
)
90
90
91
+ self ._query_appservices_for_otks = (
92
+ hs .config .experimental .msc3983_appservice_otk_claims
93
+ )
94
+
91
95
@trace
92
96
@cancellable
93
97
async def query_devices (
@@ -542,6 +546,42 @@ async def on_federation_query_client_keys(
542
546
543
547
return ret
544
548
549
+ async def claim_local_one_time_keys (
550
+ self , local_query : List [Tuple [str , str , str ]]
551
+ ) -> Iterable [Dict [str , Dict [str , Dict [str , JsonDict ]]]]:
552
+ """Claim one time keys for local users.
553
+
554
+ 1. Attempt to claim OTKs from the database.
555
+ 2. Ask application services if they provide OTKs.
556
+ 3. Attempt to fetch fallback keys from the database.
557
+
558
+ Args:
559
+ local_query: An iterable of tuples of (user ID, device ID, algorithm).
560
+
561
+ Returns:
562
+ An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
563
+ """
564
+
565
+ otk_results , not_found = await self .store .claim_e2e_one_time_keys (local_query )
566
+
567
+ # If the application services have not provided any keys via the C-S
568
+ # API, query it directly for one-time keys.
569
+ if self ._query_appservices_for_otks :
570
+ (
571
+ appservice_results ,
572
+ not_found ,
573
+ ) = await self ._appservice_handler .claim_e2e_one_time_keys (not_found )
574
+ else :
575
+ appservice_results = []
576
+
577
+ # For each user that does not have a one-time keys available, see if
578
+ # there is a fallback key.
579
+ fallback_results = await self .store .claim_e2e_fallback_keys (not_found )
580
+
581
+ # Return the results in order, each item from the input query should
582
+ # only appear once in the combined list.
583
+ return (otk_results , * appservice_results , fallback_results )
584
+
545
585
@trace
546
586
async def claim_one_time_keys (
547
587
self , query : Dict [str , Dict [str , Dict [str , str ]]], timeout : Optional [int ]
@@ -561,17 +601,18 @@ async def claim_one_time_keys(
561
601
set_tag ("local_key_query" , str (local_query ))
562
602
set_tag ("remote_key_query" , str (remote_queries ))
563
603
564
- results = await self .store . claim_e2e_one_time_keys (local_query )
604
+ results = await self .claim_local_one_time_keys (local_query )
565
605
566
606
# A map of user ID -> device ID -> key ID -> key.
567
607
json_result : Dict [str , Dict [str , Dict [str , JsonDict ]]] = {}
608
+ for result in results :
609
+ for user_id , device_keys in result .items ():
610
+ for device_id , keys in device_keys .items ():
611
+ for key_id , key in keys .items ():
612
+ json_result .setdefault (user_id , {})[device_id ] = {key_id : key }
613
+
614
+ # Remote failures.
568
615
failures : Dict [str , JsonDict ] = {}
569
- for user_id , device_keys in results .items ():
570
- for device_id , keys in device_keys .items ():
571
- for key_id , json_str in keys .items ():
572
- json_result .setdefault (user_id , {})[device_id ] = {
573
- key_id : json_decoder .decode (json_str )
574
- }
575
616
576
617
@trace
577
618
async def claim_client_keys (destination : str ) -> None :
0 commit comments