Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 15bb1c8

Browse files
authored
Add type hints to synapse/storage/databases/main/stats.py (#11653)
1 parent fcfe675 commit 15bb1c8

File tree

3 files changed

+57
-42
lines changed

3 files changed

+57
-42
lines changed

changelog.d/11653.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to storage classes.

mypy.ini

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ exclude = (?x)
3939
|synapse/storage/databases/main/roommember.py
4040
|synapse/storage/databases/main/search.py
4141
|synapse/storage/databases/main/state.py
42-
|synapse/storage/databases/main/stats.py
4342
|synapse/storage/databases/main/user_directory.py
4443
|synapse/storage/schema/
4544

@@ -214,6 +213,9 @@ disallow_untyped_defs = True
214213
[mypy-synapse.storage.databases.main.profile]
215214
disallow_untyped_defs = True
216215

216+
[mypy-synapse.storage.databases.main.stats]
217+
disallow_untyped_defs = True
218+
217219
[mypy-synapse.storage.databases.main.state_deltas]
218220
disallow_untyped_defs = True
219221

synapse/storage/databases/main/stats.py

+53-41
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
import logging
1717
from enum import Enum
1818
from itertools import chain
19-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
19+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
2020

2121
from typing_extensions import Counter
2222

2323
from twisted.internet.defer import DeferredLock
2424

2525
from synapse.api.constants import EventContentFields, EventTypes, Membership
2626
from synapse.api.errors import StoreError
27-
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
27+
from synapse.storage.database import (
28+
DatabasePool,
29+
LoggingDatabaseConnection,
30+
LoggingTransaction,
31+
)
2832
from synapse.storage.databases.main.state_deltas import StateDeltasStore
2933
from synapse.types import JsonDict
3034
from synapse.util.caches.descriptors import cached
@@ -122,7 +126,9 @@ def __init__(
122126
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
123127
self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
124128

125-
async def _populate_stats_process_users(self, progress, batch_size):
129+
async def _populate_stats_process_users(
130+
self, progress: JsonDict, batch_size: int
131+
) -> int:
126132
"""
127133
This is a background update which regenerates statistics for users.
128134
"""
@@ -134,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):
134140

135141
last_user_id = progress.get("last_user_id", "")
136142

137-
def _get_next_batch(txn):
143+
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
138144
sql = """
139145
SELECT DISTINCT name FROM users
140146
WHERE name > ?
@@ -168,7 +174,9 @@ def _get_next_batch(txn):
168174

169175
return len(users_to_work_on)
170176

171-
async def _populate_stats_process_rooms(self, progress, batch_size):
177+
async def _populate_stats_process_rooms(
178+
self, progress: JsonDict, batch_size: int
179+
) -> int:
172180
"""This is a background update which regenerates statistics for rooms."""
173181
if not self.stats_enabled:
174182
await self.db_pool.updates._end_background_update(
@@ -178,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):
178186

179187
last_room_id = progress.get("last_room_id", "")
180188

181-
def _get_next_batch(txn):
189+
def _get_next_batch(txn: LoggingTransaction) -> List[str]:
182190
sql = """
183191
SELECT DISTINCT room_id FROM current_state_events
184192
WHERE room_id > ?
@@ -307,7 +315,7 @@ async def bulk_update_stats_delta(
307315
stream_id: Current position.
308316
"""
309317

310-
def _bulk_update_stats_delta_txn(txn):
318+
def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
311319
for stats_type, stats_updates in updates.items():
312320
for stats_id, fields in stats_updates.items():
313321
logger.debug(
@@ -339,7 +347,7 @@ async def update_stats_delta(
339347
stats_type: str,
340348
stats_id: str,
341349
fields: Dict[str, int],
342-
complete_with_stream_id: Optional[int],
350+
complete_with_stream_id: int,
343351
absolute_field_overrides: Optional[Dict[str, int]] = None,
344352
) -> None:
345353
"""
@@ -372,14 +380,14 @@ async def update_stats_delta(
372380

373381
def _update_stats_delta_txn(
374382
self,
375-
txn,
376-
ts,
377-
stats_type,
378-
stats_id,
379-
fields,
380-
complete_with_stream_id,
381-
absolute_field_overrides=None,
382-
):
383+
txn: LoggingTransaction,
384+
ts: int,
385+
stats_type: str,
386+
stats_id: str,
387+
fields: Dict[str, int],
388+
complete_with_stream_id: int,
389+
absolute_field_overrides: Optional[Dict[str, int]] = None,
390+
) -> None:
383391
if absolute_field_overrides is None:
384392
absolute_field_overrides = {}
385393

@@ -422,20 +430,23 @@ def _update_stats_delta_txn(
422430
)
423431

424432
def _upsert_with_additive_relatives_txn(
425-
self, txn, table, keyvalues, absolutes, additive_relatives
426-
):
433+
self,
434+
txn: LoggingTransaction,
435+
table: str,
436+
keyvalues: Dict[str, Any],
437+
absolutes: Dict[str, Any],
438+
additive_relatives: Dict[str, int],
439+
) -> None:
427440
"""Used to update values in the stats tables.
428441
429442
This is basically a slightly convoluted upsert that *adds* to any
430443
existing rows.
431444
432445
Args:
433-
txn
434-
table (str): Table name
435-
keyvalues (dict[str, any]): Row-identifying key values
436-
absolutes (dict[str, any]): Absolute (set) fields
437-
additive_relatives (dict[str, int]): Fields that will be added onto
438-
if existing row present.
446+
table: Table name
447+
keyvalues: Row-identifying key values
448+
absolutes: Absolute (set) fields
449+
additive_relatives: Fields that will be added onto if existing row present.
439450
"""
440451
if self.database_engine.can_native_upsert:
441452
absolute_updates = [
@@ -491,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
491502
current_row.update(absolutes)
492503
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
493504

494-
async def _calculate_and_set_initial_state_for_room(
495-
self, room_id: str
496-
) -> Tuple[dict, dict, int]:
505+
async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
497506
"""Calculate and insert an entry into room_stats_current.
498507
499508
Args:
500509
room_id: The room ID under calculation.
501-
502-
Returns:
503-
A tuple of room state, membership counts and stream position.
504510
"""
505511

506-
def _fetch_current_state_stats(txn):
507-
pos = self.get_room_max_stream_ordering()
512+
def _fetch_current_state_stats(
513+
txn: LoggingTransaction,
514+
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
515+
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
508516

509517
rows = self.db_pool.simple_select_many_txn(
510518
txn,
@@ -524,7 +532,7 @@ def _fetch_current_state_stats(txn):
524532
retcols=["event_id"],
525533
)
526534

527-
event_ids = [row["event_id"] for row in rows]
535+
event_ids = cast(List[str], [row["event_id"] for row in rows])
528536

529537
txn.execute(
530538
"""
@@ -544,9 +552,9 @@ def _fetch_current_state_stats(txn):
544552
(room_id,),
545553
)
546554

547-
(current_state_events_count,) = txn.fetchone()
555+
current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
548556

549-
users_in_room = self.get_users_in_room_txn(txn, room_id)
557+
users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]
550558

551559
return (
552560
event_ids,
@@ -566,7 +574,7 @@ def _fetch_current_state_stats(txn):
566574
"get_initial_state_for_room", _fetch_current_state_stats
567575
)
568576

569-
state_event_map = await self.get_events(event_ids, get_prev_content=False)
577+
state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
570578

571579
room_state = {
572580
"join_rules": None,
@@ -622,8 +630,10 @@ def _fetch_current_state_stats(txn):
622630
},
623631
)
624632

625-
async def _calculate_and_set_initial_state_for_user(self, user_id):
626-
def _calculate_and_set_initial_state_for_user_txn(txn):
633+
async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
634+
def _calculate_and_set_initial_state_for_user_txn(
635+
txn: LoggingTransaction,
636+
) -> Tuple[int, int]:
627637
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
628638

629639
txn.execute(
@@ -634,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
634644
""",
635645
(user_id,),
636646
)
637-
(count,) = txn.fetchone()
647+
count = cast(Tuple[int], txn.fetchone())[0]
638648
return count, pos
639649

640650
joined_rooms, pos = await self.db_pool.runInteraction(
@@ -678,7 +688,9 @@ async def get_users_media_usage_paginate(
678688
users that exist given this query
679689
"""
680690

681-
def get_users_media_usage_paginate_txn(txn):
691+
def get_users_media_usage_paginate_txn(
692+
txn: LoggingTransaction,
693+
) -> Tuple[List[JsonDict], int]:
682694
filters = []
683695
args = [self.hs.config.server.server_name]
684696

@@ -733,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
733745
sql_base=sql_base,
734746
)
735747
txn.execute(sql, args)
736-
count = txn.fetchone()[0]
748+
count = cast(Tuple[int], txn.fetchone())[0]
737749

738750
sql = """
739751
SELECT

0 commit comments

Comments
 (0)