Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 6edefef

Browse files
authored
Add some type hints to datastore (#12717)
1 parent 942c30b commit 6edefef

File tree

10 files changed

+254
-161
lines changed

10 files changed

+254
-161
lines changed

changelog.d/12717.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add some type hints to datastore.

mypy.ini

-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ exclude = (?x)
2828
|synapse/storage/databases/main/cache.py
2929
|synapse/storage/databases/main/devices.py
3030
|synapse/storage/databases/main/event_federation.py
31-
|synapse/storage/databases/main/push_rule.py
32-
|synapse/storage/databases/main/roommember.py
3331
|synapse/storage/schema/
3432

3533
|tests/api/test_auth.py

synapse/federation/sender/__init__.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@
1515
import abc
1616
import logging
1717
from collections import OrderedDict
18-
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
18+
from typing import (
19+
TYPE_CHECKING,
20+
Collection,
21+
Dict,
22+
Hashable,
23+
Iterable,
24+
List,
25+
Optional,
26+
Set,
27+
Tuple,
28+
)
1929

2030
import attr
2131
from prometheus_client import Counter
@@ -409,7 +419,7 @@ async def handle_event(event: EventBase) -> None:
409419
)
410420
return
411421

412-
destinations: Optional[Set[str]] = None
422+
destinations: Optional[Collection[str]] = None
413423
if not event.prev_event_ids():
414424
# If there are no prev event IDs then the state is empty
415425
# and so no remote servers in the room
@@ -444,7 +454,7 @@ async def handle_event(event: EventBase) -> None:
444454
)
445455
return
446456

447-
destinations = {
457+
sharded_destinations = {
448458
d
449459
for d in destinations
450460
if self._federation_shard_config.should_handle(
@@ -456,12 +466,12 @@ async def handle_event(event: EventBase) -> None:
456466
# If we are sending the event on behalf of another server
457467
# then it already has the event and there is no reason to
458468
# send the event to it.
459-
destinations.discard(send_on_behalf_of)
469+
sharded_destinations.discard(send_on_behalf_of)
460470

461-
logger.debug("Sending %s to %r", event, destinations)
471+
logger.debug("Sending %s to %r", event, sharded_destinations)
462472

463-
if destinations:
464-
await self._send_pdu(event, destinations)
473+
if sharded_destinations:
474+
await self._send_pdu(event, sharded_destinations)
465475

466476
now = self.clock.time_msec()
467477
ts = await self.store.get_received_ts(event.event_id)

synapse/handlers/sync.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,10 @@ async def current_sync_for_user(
411411
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
412412
return sync_result
413413

414-
async def push_rules_for_user(self, user: UserID) -> JsonDict:
414+
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
415415
user_id = user.to_string()
416-
rules = await self.store.get_push_rules_for_user(user_id)
417-
rules = format_push_rules_for_user(user, rules)
416+
rules_raw = await self.store.get_push_rules_for_user(user_id)
417+
rules = format_push_rules_for_user(user, rules_raw)
418418
return rules
419419

420420
async def ephemeral_by_room(

synapse/rest/client/push_rule.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDic
148148
# we build up the full structure and then decide which bits of it
149149
# to send which means doing unnecessary work sometimes but is
150150
# is probably not going to make a whole lot of difference
151-
rules = await self.store.get_push_rules_for_user(user_id)
151+
rules_raw = await self.store.get_push_rules_for_user(user_id)
152152

153-
rules = format_push_rules_for_user(requester.user, rules)
153+
rules = format_push_rules_for_user(requester.user, rules_raw)
154154

155155
path_parts = path.split("/")[1:]
156156

synapse/state/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ async def get_current_users_in_room(
239239
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
240240
return await self.store.get_joined_users_from_state(room_id, entry)
241241

242-
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
242+
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
243243
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
244244
return await self.get_hosts_in_room_at_events(room_id, event_ids)
245245

246246
async def get_hosts_in_room_at_events(
247247
self, room_id: str, event_ids: Collection[str]
248-
) -> Set[str]:
248+
) -> FrozenSet[str]:
249249
"""Get the hosts that were in a room at the given event ids
250250
251251
Args:

synapse/storage/databases/main/__init__.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@
2626
from synapse.storage.databases.main.stats import UserSortOrder
2727
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
2828
from synapse.storage.types import Cursor
29-
from synapse.storage.util.id_generators import (
30-
IdGenerator,
31-
MultiWriterIdGenerator,
32-
StreamIdGenerator,
33-
)
29+
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
3430
from synapse.types import JsonDict, get_domain_from_id
3531
from synapse.util.caches.stream_change_cache import StreamChangeCache
3632

@@ -155,8 +151,6 @@ def __init__(
155151
],
156152
)
157153

158-
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
159-
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
160154
self._group_updates_id_gen = StreamIdGenerator(
161155
db_conn, "local_group_updates", "stream_id"
162156
)

synapse/storage/databases/main/metrics.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,19 @@
1414
import calendar
1515
import logging
1616
import time
17-
from typing import TYPE_CHECKING, Dict
17+
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
1818

1919
from synapse.metrics import GaugeBucketCollector
2020
from synapse.metrics.background_process_metrics import wrap_as_background_process
2121
from synapse.storage._base import SQLBaseStore
22-
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
22+
from synapse.storage.database import (
23+
DatabasePool,
24+
LoggingDatabaseConnection,
25+
LoggingTransaction,
26+
)
2327
from synapse.storage.databases.main.event_push_actions import (
2428
EventPushActionsWorkerStore,
2529
)
26-
from synapse.storage.types import Cursor
2730

2831
if TYPE_CHECKING:
2932
from synapse.server import HomeServer
@@ -73,7 +76,7 @@ def __init__(
7376

7477
@wrap_as_background_process("read_forward_extremities")
7578
async def _read_forward_extremities(self) -> None:
76-
def fetch(txn):
79+
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
7780
txn.execute(
7881
"""
7982
SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ def fetch(txn):
8689
) t2 ON t1.room_id = t2.room_id
8790
"""
8891
)
89-
return txn.fetchall()
92+
return cast(List[Tuple[int, int]], txn.fetchall())
9093

9194
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
9295

@@ -104,20 +107,20 @@ async def count_daily_e2ee_messages(self) -> int:
104107
call to this function, it will return None.
105108
"""
106109

107-
def _count_messages(txn):
110+
def _count_messages(txn: LoggingTransaction) -> int:
108111
sql = """
109112
SELECT COUNT(*) FROM events
110113
WHERE type = 'm.room.encrypted'
111114
AND stream_ordering > ?
112115
"""
113116
txn.execute(sql, (self.stream_ordering_day_ago,))
114-
(count,) = txn.fetchone()
117+
(count,) = cast(Tuple[int], txn.fetchone())
115118
return count
116119

117120
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
118121

119122
async def count_daily_sent_e2ee_messages(self) -> int:
120-
def _count_messages(txn):
123+
def _count_messages(txn: LoggingTransaction) -> int:
121124
# This is good enough as if you have silly characters in your own
122125
# hostname then that's your own fault.
123126
like_clause = "%:" + self.hs.hostname
@@ -130,22 +133,22 @@ def _count_messages(txn):
130133
"""
131134

132135
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
133-
(count,) = txn.fetchone()
136+
(count,) = cast(Tuple[int], txn.fetchone())
134137
return count
135138

136139
return await self.db_pool.runInteraction(
137140
"count_daily_sent_e2ee_messages", _count_messages
138141
)
139142

140143
async def count_daily_active_e2ee_rooms(self) -> int:
141-
def _count(txn):
144+
def _count(txn: LoggingTransaction) -> int:
142145
sql = """
143146
SELECT COUNT(DISTINCT room_id) FROM events
144147
WHERE type = 'm.room.encrypted'
145148
AND stream_ordering > ?
146149
"""
147150
txn.execute(sql, (self.stream_ordering_day_ago,))
148-
(count,) = txn.fetchone()
151+
(count,) = cast(Tuple[int], txn.fetchone())
149152
return count
150153

151154
return await self.db_pool.runInteraction(
@@ -160,20 +163,20 @@ async def count_daily_messages(self) -> int:
160163
call to this function, it will return None.
161164
"""
162165

163-
def _count_messages(txn):
166+
def _count_messages(txn: LoggingTransaction) -> int:
164167
sql = """
165168
SELECT COUNT(*) FROM events
166169
WHERE type = 'm.room.message'
167170
AND stream_ordering > ?
168171
"""
169172
txn.execute(sql, (self.stream_ordering_day_ago,))
170-
(count,) = txn.fetchone()
173+
(count,) = cast(Tuple[int], txn.fetchone())
171174
return count
172175

173176
return await self.db_pool.runInteraction("count_messages", _count_messages)
174177

175178
async def count_daily_sent_messages(self) -> int:
176-
def _count_messages(txn):
179+
def _count_messages(txn: LoggingTransaction) -> int:
177180
# This is good enough as if you have silly characters in your own
178181
# hostname then that's your own fault.
179182
like_clause = "%:" + self.hs.hostname
@@ -186,22 +189,22 @@ def _count_messages(txn):
186189
"""
187190

188191
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
189-
(count,) = txn.fetchone()
192+
(count,) = cast(Tuple[int], txn.fetchone())
190193
return count
191194

192195
return await self.db_pool.runInteraction(
193196
"count_daily_sent_messages", _count_messages
194197
)
195198

196199
async def count_daily_active_rooms(self) -> int:
197-
def _count(txn):
200+
def _count(txn: LoggingTransaction) -> int:
198201
sql = """
199202
SELECT COUNT(DISTINCT room_id) FROM events
200203
WHERE type = 'm.room.message'
201204
AND stream_ordering > ?
202205
"""
203206
txn.execute(sql, (self.stream_ordering_day_ago,))
204-
(count,) = txn.fetchone()
207+
(count,) = cast(Tuple[int], txn.fetchone())
205208
return count
206209

207210
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -227,7 +230,7 @@ async def count_monthly_users(self) -> int:
227230
"count_monthly_users", self._count_users, thirty_days_ago
228231
)
229232

230-
def _count_users(self, txn: Cursor, time_from: int) -> int:
233+
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
231234
"""
232235
Returns number of users seen in the past time_from period
233236
"""
@@ -242,7 +245,7 @@ def _count_users(self, txn: Cursor, time_from: int) -> int:
242245
# Mypy knows that fetchone() might return None if there are no rows.
243246
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
244247
# returns exactly one row.
245-
(count,) = txn.fetchone() # type: ignore[misc]
248+
(count,) = cast(Tuple[int], txn.fetchone())
246249
return count
247250

248251
async def count_r30_users(self) -> Dict[str, int]:
@@ -256,7 +259,7 @@ async def count_r30_users(self) -> Dict[str, int]:
256259
A mapping of counts globally as well as broken out by platform.
257260
"""
258261

259-
def _count_r30_users(txn):
262+
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
260263
thirty_days_in_secs = 86400 * 30
261264
now = int(self._clock.time())
262265
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ def _count_r30_users(txn):
321324

322325
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
323326

324-
(count,) = txn.fetchone()
327+
(count,) = cast(Tuple[int], txn.fetchone())
325328
results["all"] = count
326329

327330
return results
@@ -348,7 +351,7 @@ async def count_r30v2_users(self) -> Dict[str, int]:
348351
- "web" (any web application -- it's not possible to distinguish Element Web here)
349352
"""
350353

351-
def _count_r30v2_users(txn):
354+
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
352355
thirty_days_in_secs = 86400 * 30
353356
now = int(self._clock.time())
354357
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ def _count_r30v2_users(txn):
445448
thirty_days_in_secs * 1000,
446449
),
447450
)
448-
row = txn.fetchone()
449-
if row is None:
450-
results["all"] = 0
451-
else:
452-
results["all"] = row[0]
451+
(count,) = cast(Tuple[int], txn.fetchone())
452+
results["all"] = count
453453

454454
return results
455455

@@ -471,7 +471,7 @@ async def generate_user_daily_visits(self) -> None:
471471
Generates daily visit data for use in cohort/ retention analysis
472472
"""
473473

474-
def _generate_user_daily_visits(txn):
474+
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
475475
logger.info("Calling _generate_user_daily_visits")
476476
today_start = self._get_start_of_day()
477477
a_day_in_milliseconds = 24 * 60 * 60 * 1000

0 commit comments

Comments
 (0)