Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2519bea

Browse files
authored
Add missing type hints to synapse.appservice (#11360)
1 parent 70ca053 commit 2519bea

File tree

7 files changed

+148
-93
lines changed

7 files changed

+148
-93
lines changed

changelog.d/11360.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to `synapse.appservice`.

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ disallow_untyped_defs = True
143143
[mypy-synapse.app.*]
144144
disallow_untyped_defs = True
145145

146+
[mypy-synapse.appservice.*]
147+
disallow_untyped_defs = True
148+
146149
[mypy-synapse.config._base]
147150
disallow_untyped_defs = True
148151

synapse/appservice/__init__.py

+61-40
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import logging
1516
import re
1617
from enum import Enum
17-
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
18+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
19+
20+
import attr
21+
from netaddr import IPSet
1822

1923
from synapse.api.constants import EventTypes
2024
from synapse.events import EventBase
@@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
3337
UP = "up"
3438

3539

40+
@attr.s(slots=True, frozen=True, auto_attribs=True)
41+
class Namespace:
42+
exclusive: bool
43+
group_id: Optional[str]
44+
regex: Pattern[str]
45+
46+
3647
class ApplicationService:
3748
"""Defines an application service. This definition is mostly what is
3849
provided to the /register AS API.
@@ -50,17 +61,17 @@ class ApplicationService:
5061

5162
def __init__(
5263
self,
53-
token,
54-
hostname,
55-
id,
56-
sender,
57-
url=None,
58-
namespaces=None,
59-
hs_token=None,
60-
protocols=None,
61-
rate_limited=True,
62-
ip_range_whitelist=None,
63-
supports_ephemeral=False,
64+
token: str,
65+
hostname: str,
66+
id: str,
67+
sender: str,
68+
url: Optional[str] = None,
69+
namespaces: Optional[JsonDict] = None,
70+
hs_token: Optional[str] = None,
71+
protocols: Optional[Iterable[str]] = None,
72+
rate_limited: bool = True,
73+
ip_range_whitelist: Optional[IPSet] = None,
74+
supports_ephemeral: bool = False,
6475
):
6576
self.token = token
6677
self.url = (
@@ -85,27 +96,33 @@ def __init__(
8596

8697
self.rate_limited = rate_limited
8798

88-
def _check_namespaces(self, namespaces):
99+
def _check_namespaces(
100+
self, namespaces: Optional[JsonDict]
101+
) -> Dict[str, List[Namespace]]:
89102
# Sanity check that it is of the form:
90103
# {
91104
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
92105
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
93106
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
94107
# }
95-
if not namespaces:
108+
if namespaces is None:
96109
namespaces = {}
97110

111+
result: Dict[str, List[Namespace]] = {}
112+
98113
for ns in ApplicationService.NS_LIST:
114+
result[ns] = []
115+
99116
if ns not in namespaces:
100-
namespaces[ns] = []
101117
continue
102118

103-
if type(namespaces[ns]) != list:
119+
if not isinstance(namespaces[ns], list):
104120
raise ValueError("Bad namespace value for '%s'" % ns)
105121
for regex_obj in namespaces[ns]:
106122
if not isinstance(regex_obj, dict):
107123
raise ValueError("Expected dict regex for ns '%s'" % ns)
108-
if not isinstance(regex_obj.get("exclusive"), bool):
124+
exclusive = regex_obj.get("exclusive")
125+
if not isinstance(exclusive, bool):
109126
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
110127
group_id = regex_obj.get("group_id")
111128
if group_id:
@@ -126,22 +143,26 @@ def _check_namespaces(self, namespaces):
126143
)
127144

128145
regex = regex_obj.get("regex")
129-
if isinstance(regex, str):
130-
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
131-
else:
146+
if not isinstance(regex, str):
132147
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
133-
return namespaces
134148

135-
def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
136-
for regex_obj in self.namespaces[namespace_key]:
137-
if regex_obj["regex"].match(test_string):
138-
return regex_obj
149+
# Pre-compile regex.
150+
result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
151+
152+
return result
153+
154+
def _matches_regex(
155+
self, namespace_key: str, test_string: str
156+
) -> Optional[Namespace]:
157+
for namespace in self.namespaces[namespace_key]:
158+
if namespace.regex.match(test_string):
159+
return namespace
139160
return None
140161

141-
def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
142-
regex_obj = self._matches_regex(test_string, ns_key)
143-
if regex_obj:
144-
return regex_obj["exclusive"]
162+
def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
163+
namespace = self._matches_regex(namespace_key, test_string)
164+
if namespace:
165+
return namespace.exclusive
145166
return False
146167

147168
async def _matches_user(
@@ -260,15 +281,15 @@ async def is_interested_in_presence(
260281

261282
def is_interested_in_user(self, user_id: str) -> bool:
262283
return (
263-
bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
284+
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
264285
or user_id == self.sender
265286
)
266287

267288
def is_interested_in_alias(self, alias: str) -> bool:
268-
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
289+
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
269290

270291
def is_interested_in_room(self, room_id: str) -> bool:
271-
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
292+
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
272293

273294
def is_exclusive_user(self, user_id: str) -> bool:
274295
return (
@@ -285,14 +306,14 @@ def is_exclusive_alias(self, alias: str) -> bool:
285306
def is_exclusive_room(self, room_id: str) -> bool:
286307
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
287308

288-
def get_exclusive_user_regexes(self):
309+
def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
289310
"""Get the list of regexes used to determine if a user is exclusively
290311
registered by the AS
291312
"""
292313
return [
293-
regex_obj["regex"]
294-
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
295-
if regex_obj["exclusive"]
314+
namespace.regex
315+
for namespace in self.namespaces[ApplicationService.NS_USERS]
316+
if namespace.exclusive
296317
]
297318

298319
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
305326
An iterable that yields group_id strings.
306327
"""
307328
return (
308-
regex_obj["group_id"]
309-
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
310-
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
329+
namespace.group_id
330+
for namespace in self.namespaces[ApplicationService.NS_USERS]
331+
if namespace.group_id and namespace.regex.match(user_id)
311332
)
312333

313334
def is_rate_limited(self) -> bool:
314335
return self.rate_limited
315336

316-
def __str__(self):
337+
def __str__(self) -> str:
317338
# copy dictionary and redact token fields so they don't get logged
318339
dict_copy = self.__dict__.copy()
319340
dict_copy["token"] = "<redacted>"

synapse/appservice/api.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
import urllib
16-
from typing import TYPE_CHECKING, List, Optional, Tuple
15+
import urllib.parse
16+
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
1717

1818
from prometheus_client import Counter
1919

@@ -53,15 +53,15 @@
5353
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
5454

5555

56-
def _is_valid_3pe_metadata(info):
56+
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
5757
if "instances" not in info:
5858
return False
5959
if not isinstance(info["instances"], list):
6060
return False
6161
return True
6262

6363

64-
def _is_valid_3pe_result(r, field):
64+
def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
6565
if not isinstance(r, dict):
6666
return False
6767

@@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"):
9393
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
9494
)
9595

96-
async def query_user(self, service, user_id):
96+
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
9797
if service.url is None:
9898
return False
99+
100+
# This is required by the configuration.
101+
assert service.hs_token is not None
102+
99103
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
100104
try:
101105
response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -109,9 +113,13 @@ async def query_user(self, service, user_id):
109113
logger.warning("query_user to %s threw exception %s", uri, ex)
110114
return False
111115

112-
async def query_alias(self, service, alias):
116+
async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
113117
if service.url is None:
114118
return False
119+
120+
# This is required by the configuration.
121+
assert service.hs_token is not None
122+
115123
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
116124
try:
117125
response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -125,7 +133,13 @@ async def query_alias(self, service, alias):
125133
logger.warning("query_alias to %s threw exception %s", uri, ex)
126134
return False
127135

128-
async def query_3pe(self, service, kind, protocol, fields):
136+
async def query_3pe(
137+
self,
138+
service: "ApplicationService",
139+
kind: str,
140+
protocol: str,
141+
fields: Dict[bytes, List[bytes]],
142+
) -> List[JsonDict]:
129143
if kind == ThirdPartyEntityKind.USER:
130144
required_field = "userid"
131145
elif kind == ThirdPartyEntityKind.LOCATION:
@@ -205,11 +219,14 @@ async def push_bulk(
205219
events: List[EventBase],
206220
ephemeral: List[JsonDict],
207221
txn_id: Optional[int] = None,
208-
):
222+
) -> bool:
209223
if service.url is None:
210224
return True
211225

212-
events = self._serialize(service, events)
226+
# This is required by the configuration.
227+
assert service.hs_token is not None
228+
229+
serialized_events = self._serialize(service, events)
213230

214231
if txn_id is None:
215232
logger.warning(
@@ -221,9 +238,12 @@ async def push_bulk(
221238

222239
# Never send ephemeral events to appservices that do not support it
223240
if service.supports_ephemeral:
224-
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
241+
body = {
242+
"events": serialized_events,
243+
"de.sorunome.msc2409.ephemeral": ephemeral,
244+
}
225245
else:
226-
body = {"events": events}
246+
body = {"events": serialized_events}
227247

228248
try:
229249
await self.put_json(
@@ -238,7 +258,7 @@ async def push_bulk(
238258
[event.get("event_id") for event in events],
239259
)
240260
sent_transactions_counter.labels(service.id).inc()
241-
sent_events_counter.labels(service.id).inc(len(events))
261+
sent_events_counter.labels(service.id).inc(len(serialized_events))
242262
return True
243263
except CodeMessageException as e:
244264
logger.warning(
@@ -260,7 +280,9 @@ async def push_bulk(
260280
failed_transactions_counter.labels(service.id).inc()
261281
return False
262282

263-
def _serialize(self, service, events):
283+
def _serialize(
284+
self, service: "ApplicationService", events: Iterable[EventBase]
285+
) -> List[JsonDict]:
264286
time_now = self.clock.time_msec()
265287
return [
266288
serialize_event(

0 commit comments

Comments
 (0)