Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit c69d78f

Browse files
clokepH-Shay
authored andcommitted
Add a type hint for get_device_handler() and fix incorrect types. (#14055)
This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places.
1 parent 56f8fea commit c69d78f

File tree

16 files changed

+185
-77
lines changed

16 files changed

+185
-77
lines changed

Diff for: changelog.d/14055.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to `HomeServer`.

Diff for: synapse/handlers/deactivate_account.py

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import TYPE_CHECKING, Optional
1717

1818
from synapse.api.errors import SynapseError
19+
from synapse.handlers.device import DeviceHandler
1920
from synapse.metrics.background_process_metrics import run_as_background_process
2021
from synapse.types import Codes, Requester, UserID, create_requester
2122

@@ -76,6 +77,9 @@ async def deactivate_account(
7677
True if identity server supports removing threepids, otherwise False.
7778
"""
7879

80+
# This can only be called on the main process.
81+
assert isinstance(self._device_handler, DeviceHandler)
82+
7983
# Check if this user can be deactivated
8084
if not await self._third_party_rules.check_can_deactivate_user(
8185
user_id, by_admin

Diff for: synapse/handlers/device.py

+50-15
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565

6666

6767
class DeviceWorkerHandler:
68+
device_list_updater: "DeviceListWorkerUpdater"
69+
6870
def __init__(self, hs: "HomeServer"):
6971
self.clock = hs.get_clock()
7072
self.hs = hs
@@ -76,6 +78,8 @@ def __init__(self, hs: "HomeServer"):
7678
self.server_name = hs.hostname
7779
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
7880

81+
self.device_list_updater = DeviceListWorkerUpdater(hs)
82+
7983
@trace
8084
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
8185
"""
@@ -99,6 +103,19 @@ async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
99103
log_kv(device_map)
100104
return devices
101105

106+
async def get_dehydrated_device(
107+
self, user_id: str
108+
) -> Optional[Tuple[str, JsonDict]]:
109+
"""Retrieve the information for a dehydrated device.
110+
111+
Args:
112+
user_id: the user whose dehydrated device we are looking for
113+
Returns:
114+
a tuple whose first item is the device ID, and the second item is
115+
the dehydrated device information
116+
"""
117+
return await self.store.get_dehydrated_device(user_id)
118+
102119
@trace
103120
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
104121
"""Retrieve the given device
@@ -127,7 +144,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:
127144
@cancellable
128145
async def get_device_changes_in_shared_rooms(
129146
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
130-
) -> Collection[str]:
147+
) -> Set[str]:
131148
"""Get the set of users whose devices have changed who share a room with
132149
the given user.
133150
"""
@@ -320,6 +337,8 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:
320337

321338

322339
class DeviceHandler(DeviceWorkerHandler):
340+
device_list_updater: "DeviceListUpdater"
341+
323342
def __init__(self, hs: "HomeServer"):
324343
super().__init__(hs)
325344

@@ -606,19 +625,6 @@ async def store_dehydrated_device(
606625
await self.delete_devices(user_id, [old_device_id])
607626
return device_id
608627

609-
async def get_dehydrated_device(
610-
self, user_id: str
611-
) -> Optional[Tuple[str, JsonDict]]:
612-
"""Retrieve the information for a dehydrated device.
613-
614-
Args:
615-
user_id: the user whose dehydrated device we are looking for
616-
Returns:
617-
a tuple whose first item is the device ID, and the second item is
618-
the dehydrated device information
619-
"""
620-
return await self.store.get_dehydrated_device(user_id)
621-
622628
async def rehydrate_device(
623629
self, user_id: str, access_token: str, device_id: str
624630
) -> dict:
@@ -882,7 +888,36 @@ def _update_device_from_client_ips(
882888
)
883889

884890

885-
class DeviceListUpdater:
891+
class DeviceListWorkerUpdater:
892+
"Handles incoming device list updates from federation and contacts the main process over replication"
893+
894+
def __init__(self, hs: "HomeServer"):
895+
from synapse.replication.http.devices import (
896+
ReplicationUserDevicesResyncRestServlet,
897+
)
898+
899+
self._user_device_resync_client = (
900+
ReplicationUserDevicesResyncRestServlet.make_client(hs)
901+
)
902+
903+
async def user_device_resync(
904+
self, user_id: str, mark_failed_as_stale: bool = True
905+
) -> Optional[JsonDict]:
906+
"""Fetches all devices for a user and updates the device cache with them.
907+
908+
Args:
909+
user_id: The user's id whose device_list will be updated.
910+
mark_failed_as_stale: Whether to mark the user's device list as stale
911+
if the attempt to resync failed.
912+
Returns:
913+
A dict with device info as under the "devices" in the result of this
914+
request:
915+
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
916+
"""
917+
return await self._user_device_resync_client(user_id=user_id)
918+
919+
920+
class DeviceListUpdater(DeviceListWorkerUpdater):
886921
"Handles incoming device list updates from federation and updates the DB"
887922

888923
def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):

Diff for: synapse/handlers/e2e_keys.py

+32-29
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727

2828
from synapse.api.constants import EduTypes
2929
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
30+
from synapse.handlers.device import DeviceHandler
3031
from synapse.logging.context import make_deferred_yieldable, run_in_background
3132
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
32-
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
3333
from synapse.types import (
3434
JsonDict,
3535
UserID,
@@ -56,27 +56,23 @@ def __init__(self, hs: "HomeServer"):
5656
self.is_mine = hs.is_mine
5757
self.clock = hs.get_clock()
5858

59-
self._edu_updater = SigningKeyEduUpdater(hs, self)
60-
6159
federation_registry = hs.get_federation_registry()
6260

63-
self._is_master = hs.config.worker.worker_app is None
64-
if not self._is_master:
65-
self._user_device_resync_client = (
66-
ReplicationUserDevicesResyncRestServlet.make_client(hs)
67-
)
68-
else:
61+
is_master = hs.config.worker.worker_app is None
62+
if is_master:
63+
edu_updater = SigningKeyEduUpdater(hs)
64+
6965
# Only register this edu handler on master as it requires writing
7066
# device updates to the db
7167
federation_registry.register_edu_handler(
7268
EduTypes.SIGNING_KEY_UPDATE,
73-
self._edu_updater.incoming_signing_key_update,
69+
edu_updater.incoming_signing_key_update,
7470
)
7571
# also handle the unstable version
7672
# FIXME: remove this when enough servers have upgraded
7773
federation_registry.register_edu_handler(
7874
EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
79-
self._edu_updater.incoming_signing_key_update,
75+
edu_updater.incoming_signing_key_update,
8076
)
8177

8278
# doesn't really work as part of the generic query API, because the
@@ -319,14 +315,13 @@ async def _query_devices_for_destination(
319315
# probably be tracking their device lists. However, we haven't
320316
# done an initial sync on the device list so we do it now.
321317
try:
322-
if self._is_master:
323-
resync_results = await self.device_handler.device_list_updater.user_device_resync(
318+
resync_results = (
319+
await self.device_handler.device_list_updater.user_device_resync(
324320
user_id
325321
)
326-
else:
327-
resync_results = await self._user_device_resync_client(
328-
user_id=user_id
329-
)
322+
)
323+
if resync_results is None:
324+
raise ValueError("Device resync failed")
330325

331326
# Add the device keys to the results.
332327
user_devices = resync_results["devices"]
@@ -605,6 +600,8 @@ async def claim_client_keys(destination: str) -> None:
605600
async def upload_keys_for_user(
606601
self, user_id: str, device_id: str, keys: JsonDict
607602
) -> JsonDict:
603+
# This can only be called from the main process.
604+
assert isinstance(self.device_handler, DeviceHandler)
608605

609606
time_now = self.clock.time_msec()
610607

@@ -732,6 +729,8 @@ async def upload_signing_keys_for_user(
732729
user_id: the user uploading the keys
733730
keys: the signing keys
734731
"""
732+
# This can only be called from the main process.
733+
assert isinstance(self.device_handler, DeviceHandler)
735734

736735
# if a master key is uploaded, then check it. Otherwise, load the
737736
# stored master key, to check signatures on other keys
@@ -823,6 +822,9 @@ async def upload_signatures_for_device_keys(
823822
Raises:
824823
SynapseError: if the signatures dict is not valid.
825824
"""
825+
# This can only be called from the main process.
826+
assert isinstance(self.device_handler, DeviceHandler)
827+
826828
failures = {}
827829

828830
# signatures to be stored. Each item will be a SignatureListItem
@@ -1200,6 +1202,9 @@ async def _retrieve_cross_signing_keys_for_remote_user(
12001202
A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
12011203
If the key cannot be retrieved, all values in the tuple will instead be None.
12021204
"""
1205+
# This can only be called from the main process.
1206+
assert isinstance(self.device_handler, DeviceHandler)
1207+
12031208
try:
12041209
remote_result = await self.federation.query_user_devices(
12051210
user.domain, user.to_string()
@@ -1396,11 +1401,14 @@ class SignatureListItem:
13961401
class SigningKeyEduUpdater:
13971402
"""Handles incoming signing key updates from federation and updates the DB"""
13981403

1399-
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
1404+
def __init__(self, hs: "HomeServer"):
14001405
self.store = hs.get_datastores().main
14011406
self.federation = hs.get_federation_client()
14021407
self.clock = hs.get_clock()
1403-
self.e2e_keys_handler = e2e_keys_handler
1408+
1409+
device_handler = hs.get_device_handler()
1410+
assert isinstance(device_handler, DeviceHandler)
1411+
self._device_handler = device_handler
14041412

14051413
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
14061414

@@ -1445,9 +1453,6 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
14451453
user_id: the user whose updates we are processing
14461454
"""
14471455

1448-
device_handler = self.e2e_keys_handler.device_handler
1449-
device_list_updater = device_handler.device_list_updater
1450-
14511456
async with self._remote_edu_linearizer.queue(user_id):
14521457
pending_updates = self._pending_updates.pop(user_id, [])
14531458
if not pending_updates:
@@ -1459,13 +1464,11 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
14591464
logger.info("pending updates: %r", pending_updates)
14601465

14611466
for master_key, self_signing_key in pending_updates:
1462-
new_device_ids = (
1463-
await device_list_updater.process_cross_signing_key_update(
1464-
user_id,
1465-
master_key,
1466-
self_signing_key,
1467-
)
1467+
new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
1468+
user_id,
1469+
master_key,
1470+
self_signing_key,
14681471
)
14691472
device_ids = device_ids + new_device_ids
14701473

1471-
await device_handler.notify_device_update(user_id, device_ids)
1474+
await self._device_handler.notify_device_update(user_id, device_ids)

Diff for: synapse/handlers/register.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from synapse.appservice import ApplicationService
4040
from synapse.config.server import is_threepid_reserved
41+
from synapse.handlers.device import DeviceHandler
4142
from synapse.http.servlet import assert_params_in_dict
4243
from synapse.replication.http.login import RegisterDeviceReplicationServlet
4344
from synapse.replication.http.register import (
@@ -841,6 +842,9 @@ class and RegisterDeviceReplicationServlet.
841842
refresh_token = None
842843
refresh_token_id = None
843844

845+
# This can only run on the main process.
846+
assert isinstance(self.device_handler, DeviceHandler)
847+
844848
registered_device_id = await self.device_handler.check_device_registered(
845849
user_id,
846850
device_id,

Diff for: synapse/handlers/set_password.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import TYPE_CHECKING, Optional
1616

1717
from synapse.api.errors import Codes, StoreError, SynapseError
18+
from synapse.handlers.device import DeviceHandler
1819
from synapse.types import Requester
1920

2021
if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
2930
def __init__(self, hs: "HomeServer"):
3031
self.store = hs.get_datastores().main
3132
self._auth_handler = hs.get_auth_handler()
32-
self._device_handler = hs.get_device_handler()
33+
# This can only be instantiated on the main process.
34+
device_handler = hs.get_device_handler()
35+
assert isinstance(device_handler, DeviceHandler)
36+
self._device_handler = device_handler
3337

3438
async def set_password(
3539
self,

Diff for: synapse/handlers/sso.py

+9
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from synapse.api.constants import LoginType
3838
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
3939
from synapse.config.sso import SsoAttributeRequirement
40+
from synapse.handlers.device import DeviceHandler
4041
from synapse.handlers.register import init_counters_for_auth_provider
4142
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
4243
from synapse.http import get_request_user_agent
@@ -1035,13 +1036,21 @@ async def revoke_sessions_for_provider_session_id(
10351036
) -> None:
10361037
"""Revoke any devices and in-flight logins tied to a provider session.
10371038
1039+
Can only be called from the main process.
1040+
10381041
Args:
10391042
auth_provider_id: A unique identifier for this SSO provider, e.g.
10401043
"oidc" or "saml".
10411044
auth_provider_session_id: The session ID from the provider to logout
10421045
expected_user_id: The user we're expecting to logout. If set, it will ignore
10431046
sessions belonging to other users and log an error.
10441047
"""
1048+
1049+
# It is expected that this is the main process.
1050+
assert isinstance(
1051+
self._device_handler, DeviceHandler
1052+
), "revoking SSO sessions can only be called on the main process"
1053+
10451054
# Invalidate any running user-mapping sessions
10461055
to_delete = []
10471056
for session_id, session in self._username_mapping_sessions.items():

Diff for: synapse/module_api/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
ON_LOGGED_OUT_CALLBACK,
8787
AuthHandler,
8888
)
89+
from synapse.handlers.device import DeviceHandler
8990
from synapse.handlers.push_rules import RuleSpec, check_actions
9091
from synapse.http.client import SimpleHttpClient
9192
from synapse.http.server import (
@@ -207,6 +208,7 @@ def __init__(self, hs: "HomeServer", auth_handler: AuthHandler) -> None:
207208
self._registration_handler = hs.get_registration_handler()
208209
self._send_email_handler = hs.get_send_email_handler()
209210
self._push_rules_handler = hs.get_push_rules_handler()
211+
self._device_handler = hs.get_device_handler()
210212
self.custom_template_dir = hs.config.server.custom_template_directory
211213

212214
try:
@@ -784,6 +786,8 @@ def invalidate_access_token(
784786
) -> Generator["defer.Deferred[Any]", Any, None]:
785787
"""Invalidate an access token for a user
786788
789+
Can only be called from the main process.
790+
787791
Added in Synapse v0.25.0.
788792
789793
Args:
@@ -796,6 +800,10 @@ def invalidate_access_token(
796800
Raises:
797801
synapse.api.errors.AuthError: the access token is invalid
798802
"""
803+
assert isinstance(
804+
self._device_handler, DeviceHandler
805+
), "invalidate_access_token can only be called on the main process"
806+
799807
# see if the access token corresponds to a device
800808
user_info = yield defer.ensureDeferred(
801809
self._auth.get_user_by_access_token(access_token)
@@ -805,7 +813,7 @@ def invalidate_access_token(
805813
if device_id:
806814
# delete the device, which will also delete its access tokens
807815
yield defer.ensureDeferred(
808-
self._hs.get_device_handler().delete_devices(user_id, [device_id])
816+
self._device_handler.delete_devices(user_id, [device_id])
809817
)
810818
else:
811819
# no associated device. Just delete the access token.

0 commit comments

Comments
 (0)