Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1f55c04

Browse files
authored
Improve type hints for cached decorator. (#15658)
The cached decorators always return a Deferred, which was not properly propagated. It was close enough when wrapping coroutines, but failed if a bare function was wrapped.
1 parent 379eb2d commit 1f55c04

File tree

6 files changed

+73
-63
lines changed

6 files changed

+73
-63
lines changed

changelog.d/15658.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type hints.

scripts-dev/mypy_synapse_plugin.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818

1919
from typing import Callable, Optional, Type
2020

21+
from mypy.erasetype import remove_instance_last_known_values
2122
from mypy.nodes import ARG_NAMED_OPT
2223
from mypy.plugin import MethodSigContext, Plugin
2324
from mypy.typeops import bind_self
24-
from mypy.types import CallableType, NoneType, UnionType
25+
from mypy.types import CallableType, Instance, NoneType, UnionType
2526

2627

2728
class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
9293
arg_names.append("on_invalidate")
9394
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
9495

96+
# Finally we ensure the return type is a Deferred.
97+
if (
98+
isinstance(signature.ret_type, Instance)
99+
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
100+
):
101+
# If it is already a Deferred, nothing to do.
102+
ret_type = signature.ret_type
103+
else:
104+
ret_arg = None
105+
if isinstance(signature.ret_type, Instance):
106+
# If a coroutine, wrap the coroutine's return type in a Deferred.
107+
if signature.ret_type.type.fullname == "typing.Coroutine":
108+
ret_arg = signature.ret_type.args[2]
109+
110+
# If an awaitable, wrap the awaitable's final value in a Deferred.
111+
elif signature.ret_type.type.fullname == "typing.Awaitable":
112+
ret_arg = signature.ret_type.args[0]
113+
114+
# Otherwise, wrap the return value in a Deferred.
115+
if ret_arg is None:
116+
ret_arg = signature.ret_type
117+
118+
# This should be able to use ctx.api.named_generic_type, but that doesn't seem
119+
# to find the correct symbol for anything more than 1 module deep.
120+
#
121+
# modules is not part of CheckerPluginInterface. The following is a combination
122+
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
123+
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
124+
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
125+
95126
signature = signature.copy_modified(
96127
arg_types=arg_types,
97128
arg_names=arg_names,
98129
arg_kinds=arg_kinds,
130+
ret_type=ret_type,
99131
)
100132

101133
return signature

synapse/storage/databases/main/roommember.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,7 @@ async def _get_joined_hosts(
10991099
# `get_joined_hosts` is called with the "current" state group for the
11001100
# room, and so consecutive calls will be for consecutive state groups
11011101
# which point to the previous state group.
1102-
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
1102+
cache = await self._get_joined_hosts_cache(room_id)
11031103

11041104
# If the state group in the cache matches, we already have the data we need.
11051105
if state_entry.state_group == cache.state_group:

synapse/util/caches/descriptors.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,9 @@ def __init__(
220220
self.iterable = iterable
221221
self.prune_unread_entries = prune_unread_entries
222222

223-
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
223+
def __get__(
224+
self, obj: Optional[Any], owner: Optional[Type]
225+
) -> Callable[..., "defer.Deferred[Any]"]:
224226
cache: DeferredCache[CacheKey, Any] = DeferredCache(
225227
name=self.name,
226228
max_entries=self.max_entries,
@@ -232,7 +234,7 @@ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., An
232234
get_cache_key = self.cache_key_builder
233235

234236
@functools.wraps(self.orig)
235-
def _wrapped(*args: Any, **kwargs: Any) -> Any:
237+
def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
236238
# If we're passed a cache_context then we'll want to call its invalidate()
237239
# whenever we are invalidated
238240
invalidate_callback = kwargs.pop("on_invalidate", None)

tests/appservice/test_appservice.py

+29-53
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import re
15-
from typing import Generator
15+
from typing import Any, Generator
1616
from unittest.mock import Mock
1717

1818
from twisted.internet import defer
@@ -49,93 +49,81 @@ def setUp(self) -> None:
4949
@defer.inlineCallbacks
5050
def test_regex_user_id_prefix_match(
5151
self,
52-
) -> Generator["defer.Deferred[object]", object, None]:
52+
) -> Generator["defer.Deferred[Any]", object, None]:
5353
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
5454
self.event.sender = "@irc_foobar:matrix.org"
5555
self.assertTrue(
5656
(
57-
yield defer.ensureDeferred(
58-
self.service.is_interested_in_event(
59-
self.event.event_id, self.event, self.store
60-
)
57+
yield self.service.is_interested_in_event(
58+
self.event.event_id, self.event, self.store
6159
)
6260
)
6361
)
6462

6563
@defer.inlineCallbacks
6664
def test_regex_user_id_prefix_no_match(
6765
self,
68-
) -> Generator["defer.Deferred[object]", object, None]:
66+
) -> Generator["defer.Deferred[Any]", object, None]:
6967
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
7068
self.event.sender = "@someone_else:matrix.org"
7169
self.assertFalse(
7270
(
73-
yield defer.ensureDeferred(
74-
self.service.is_interested_in_event(
75-
self.event.event_id, self.event, self.store
76-
)
71+
yield self.service.is_interested_in_event(
72+
self.event.event_id, self.event, self.store
7773
)
7874
)
7975
)
8076

8177
@defer.inlineCallbacks
8278
def test_regex_room_member_is_checked(
8379
self,
84-
) -> Generator["defer.Deferred[object]", object, None]:
80+
) -> Generator["defer.Deferred[Any]", object, None]:
8581
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
8682
self.event.sender = "@someone_else:matrix.org"
8783
self.event.type = "m.room.member"
8884
self.event.state_key = "@irc_foobar:matrix.org"
8985
self.assertTrue(
9086
(
91-
yield defer.ensureDeferred(
92-
self.service.is_interested_in_event(
93-
self.event.event_id, self.event, self.store
94-
)
87+
yield self.service.is_interested_in_event(
88+
self.event.event_id, self.event, self.store
9589
)
9690
)
9791
)
9892

9993
@defer.inlineCallbacks
10094
def test_regex_room_id_match(
10195
self,
102-
) -> Generator["defer.Deferred[object]", object, None]:
96+
) -> Generator["defer.Deferred[Any]", object, None]:
10397
self.service.namespaces[ApplicationService.NS_ROOMS].append(
10498
_regex("!some_prefix.*some_suffix:matrix.org")
10599
)
106100
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
107101
self.assertTrue(
108102
(
109-
yield defer.ensureDeferred(
110-
self.service.is_interested_in_event(
111-
self.event.event_id, self.event, self.store
112-
)
103+
yield self.service.is_interested_in_event(
104+
self.event.event_id, self.event, self.store
113105
)
114106
)
115107
)
116108

117109
@defer.inlineCallbacks
118110
def test_regex_room_id_no_match(
119111
self,
120-
) -> Generator["defer.Deferred[object]", object, None]:
112+
) -> Generator["defer.Deferred[Any]", object, None]:
121113
self.service.namespaces[ApplicationService.NS_ROOMS].append(
122114
_regex("!some_prefix.*some_suffix:matrix.org")
123115
)
124116
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
125117
self.assertFalse(
126118
(
127-
yield defer.ensureDeferred(
128-
self.service.is_interested_in_event(
129-
self.event.event_id, self.event, self.store
130-
)
119+
yield self.service.is_interested_in_event(
120+
self.event.event_id, self.event, self.store
131121
)
132122
)
133123
)
134124

135125
@defer.inlineCallbacks
136-
def test_regex_alias_match(
137-
self,
138-
) -> Generator["defer.Deferred[object]", object, None]:
126+
def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]:
139127
self.service.namespaces[ApplicationService.NS_ALIASES].append(
140128
_regex("#irc_.*:matrix.org")
141129
)
@@ -145,10 +133,8 @@ def test_regex_alias_match(
145133
self.store.get_local_users_in_room = simple_async_mock([])
146134
self.assertTrue(
147135
(
148-
yield defer.ensureDeferred(
149-
self.service.is_interested_in_event(
150-
self.event.event_id, self.event, self.store
151-
)
136+
yield self.service.is_interested_in_event(
137+
self.event.event_id, self.event, self.store
152138
)
153139
)
154140
)
@@ -192,7 +178,7 @@ def test_exclusive_room(self) -> None:
192178
@defer.inlineCallbacks
193179
def test_regex_alias_no_match(
194180
self,
195-
) -> Generator["defer.Deferred[object]", object, None]:
181+
) -> Generator["defer.Deferred[Any]", object, None]:
196182
self.service.namespaces[ApplicationService.NS_ALIASES].append(
197183
_regex("#irc_.*:matrix.org")
198184
)
@@ -213,7 +199,7 @@ def test_regex_alias_no_match(
213199
@defer.inlineCallbacks
214200
def test_regex_multiple_matches(
215201
self,
216-
) -> Generator["defer.Deferred[object]", object, None]:
202+
) -> Generator["defer.Deferred[Any]", object, None]:
217203
self.service.namespaces[ApplicationService.NS_ALIASES].append(
218204
_regex("#irc_.*:matrix.org")
219205
)
@@ -223,18 +209,14 @@ def test_regex_multiple_matches(
223209
self.store.get_local_users_in_room = simple_async_mock([])
224210
self.assertTrue(
225211
(
226-
yield defer.ensureDeferred(
227-
self.service.is_interested_in_event(
228-
self.event.event_id, self.event, self.store
229-
)
212+
yield self.service.is_interested_in_event(
213+
self.event.event_id, self.event, self.store
230214
)
231215
)
232216
)
233217

234218
@defer.inlineCallbacks
235-
def test_interested_in_self(
236-
self,
237-
) -> Generator["defer.Deferred[object]", object, None]:
219+
def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]:
238220
# make sure invites get through
239221
self.service.sender = "@appservice:name"
240222
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
@@ -243,18 +225,14 @@ def test_interested_in_self(
243225
self.event.state_key = self.service.sender
244226
self.assertTrue(
245227
(
246-
yield defer.ensureDeferred(
247-
self.service.is_interested_in_event(
248-
self.event.event_id, self.event, self.store
249-
)
228+
yield self.service.is_interested_in_event(
229+
self.event.event_id, self.event, self.store
250230
)
251231
)
252232
)
253233

254234
@defer.inlineCallbacks
255-
def test_member_list_match(
256-
self,
257-
) -> Generator["defer.Deferred[object]", object, None]:
235+
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
258236
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
259237
# Note that @irc_fo:here is the AS user.
260238
self.store.get_local_users_in_room = simple_async_mock(
@@ -265,10 +243,8 @@ def test_member_list_match(
265243
self.event.sender = "@xmpp_foobar:matrix.org"
266244
self.assertTrue(
267245
(
268-
yield defer.ensureDeferred(
269-
self.service.is_interested_in_event(
270-
self.event.event_id, self.event, self.store
271-
)
246+
yield self.service.is_interested_in_event(
247+
self.event.event_id, self.event, self.store
272248
)
273249
)
274250
)

tests/storage/test_transactions.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@ def test_get_set_transactions(self) -> None:
3333
destination retries, as well as testing tht we can set and get
3434
correctly.
3535
"""
36-
d = self.store.get_destination_retry_timings("example.com")
37-
r = self.get_success(d)
36+
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
3837
self.assertIsNone(r)
3938

40-
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
41-
self.get_success(d)
39+
self.get_success(
40+
self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
41+
)
4242

43-
d = self.store.get_destination_retry_timings("example.com")
44-
r = self.get_success(d)
43+
r = self.get_success(self.store.get_destination_retry_timings("example.com"))
4544

4645
self.assertEqual(
4746
DestinationRetryTimings(

0 commit comments

Comments
 (0)