16
16
import logging
17
17
from enum import Enum
18
18
from itertools import chain
19
- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
19
+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , cast
20
20
21
21
from typing_extensions import Counter
22
22
23
23
from twisted .internet .defer import DeferredLock
24
24
25
25
from synapse .api .constants import EventContentFields , EventTypes , Membership
26
26
from synapse .api .errors import StoreError
27
- from synapse .storage .database import DatabasePool , LoggingDatabaseConnection
27
+ from synapse .storage .database import (
28
+ DatabasePool ,
29
+ LoggingDatabaseConnection ,
30
+ LoggingTransaction ,
31
+ )
28
32
from synapse .storage .databases .main .state_deltas import StateDeltasStore
29
33
from synapse .types import JsonDict
30
34
from synapse .util .caches .descriptors import cached
@@ -122,7 +126,9 @@ def __init__(
122
126
self .db_pool .updates .register_noop_background_update ("populate_stats_cleanup" )
123
127
self .db_pool .updates .register_noop_background_update ("populate_stats_prepare" )
124
128
125
- async def _populate_stats_process_users (self , progress , batch_size ):
129
+ async def _populate_stats_process_users (
130
+ self , progress : JsonDict , batch_size : int
131
+ ) -> int :
126
132
"""
127
133
This is a background update which regenerates statistics for users.
128
134
"""
@@ -134,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):
134
140
135
141
last_user_id = progress .get ("last_user_id" , "" )
136
142
137
- def _get_next_batch (txn ) :
143
+ def _get_next_batch (txn : LoggingTransaction ) -> List [ str ] :
138
144
sql = """
139
145
SELECT DISTINCT name FROM users
140
146
WHERE name > ?
@@ -168,7 +174,9 @@ def _get_next_batch(txn):
168
174
169
175
return len (users_to_work_on )
170
176
171
- async def _populate_stats_process_rooms (self , progress , batch_size ):
177
+ async def _populate_stats_process_rooms (
178
+ self , progress : JsonDict , batch_size : int
179
+ ) -> int :
172
180
"""This is a background update which regenerates statistics for rooms."""
173
181
if not self .stats_enabled :
174
182
await self .db_pool .updates ._end_background_update (
@@ -178,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):
178
186
179
187
last_room_id = progress .get ("last_room_id" , "" )
180
188
181
- def _get_next_batch (txn ) :
189
+ def _get_next_batch (txn : LoggingTransaction ) -> List [ str ] :
182
190
sql = """
183
191
SELECT DISTINCT room_id FROM current_state_events
184
192
WHERE room_id > ?
@@ -307,7 +315,7 @@ async def bulk_update_stats_delta(
307
315
stream_id: Current position.
308
316
"""
309
317
310
- def _bulk_update_stats_delta_txn (txn ) :
318
+ def _bulk_update_stats_delta_txn (txn : LoggingTransaction ) -> None :
311
319
for stats_type , stats_updates in updates .items ():
312
320
for stats_id , fields in stats_updates .items ():
313
321
logger .debug (
@@ -339,7 +347,7 @@ async def update_stats_delta(
339
347
stats_type : str ,
340
348
stats_id : str ,
341
349
fields : Dict [str , int ],
342
- complete_with_stream_id : Optional [ int ] ,
350
+ complete_with_stream_id : int ,
343
351
absolute_field_overrides : Optional [Dict [str , int ]] = None ,
344
352
) -> None :
345
353
"""
@@ -372,14 +380,14 @@ async def update_stats_delta(
372
380
373
381
def _update_stats_delta_txn (
374
382
self ,
375
- txn ,
376
- ts ,
377
- stats_type ,
378
- stats_id ,
379
- fields ,
380
- complete_with_stream_id ,
381
- absolute_field_overrides = None ,
382
- ):
383
+ txn : LoggingTransaction ,
384
+ ts : int ,
385
+ stats_type : str ,
386
+ stats_id : str ,
387
+ fields : Dict [ str , int ] ,
388
+ complete_with_stream_id : int ,
389
+ absolute_field_overrides : Optional [ Dict [ str , int ]] = None ,
390
+ ) -> None :
383
391
if absolute_field_overrides is None :
384
392
absolute_field_overrides = {}
385
393
@@ -422,20 +430,23 @@ def _update_stats_delta_txn(
422
430
)
423
431
424
432
def _upsert_with_additive_relatives_txn (
425
- self , txn , table , keyvalues , absolutes , additive_relatives
426
- ):
433
+ self ,
434
+ txn : LoggingTransaction ,
435
+ table : str ,
436
+ keyvalues : Dict [str , Any ],
437
+ absolutes : Dict [str , Any ],
438
+ additive_relatives : Dict [str , int ],
439
+ ) -> None :
427
440
"""Used to update values in the stats tables.
428
441
429
442
This is basically a slightly convoluted upsert that *adds* to any
430
443
existing rows.
431
444
432
445
Args:
433
- txn
434
- table (str): Table name
435
- keyvalues (dict[str, any]): Row-identifying key values
436
- absolutes (dict[str, any]): Absolute (set) fields
437
- additive_relatives (dict[str, int]): Fields that will be added onto
438
- if existing row present.
446
+ table: Table name
447
+ keyvalues: Row-identifying key values
448
+ absolutes: Absolute (set) fields
449
+ additive_relatives: Fields that will be added onto if existing row present.
439
450
"""
440
451
if self .database_engine .can_native_upsert :
441
452
absolute_updates = [
@@ -491,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
491
502
current_row .update (absolutes )
492
503
self .db_pool .simple_update_one_txn (txn , table , keyvalues , current_row )
493
504
494
- async def _calculate_and_set_initial_state_for_room (
495
- self , room_id : str
496
- ) -> Tuple [dict , dict , int ]:
505
+ async def _calculate_and_set_initial_state_for_room (self , room_id : str ) -> None :
497
506
"""Calculate and insert an entry into room_stats_current.
498
507
499
508
Args:
500
509
room_id: The room ID under calculation.
501
-
502
- Returns:
503
- A tuple of room state, membership counts and stream position.
504
510
"""
505
511
506
- def _fetch_current_state_stats (txn ):
507
- pos = self .get_room_max_stream_ordering ()
512
+ def _fetch_current_state_stats (
513
+ txn : LoggingTransaction ,
514
+ ) -> Tuple [List [str ], Dict [str , int ], int , List [str ], int ]:
515
+ pos = self .get_room_max_stream_ordering () # type: ignore[attr-defined]
508
516
509
517
rows = self .db_pool .simple_select_many_txn (
510
518
txn ,
@@ -524,7 +532,7 @@ def _fetch_current_state_stats(txn):
524
532
retcols = ["event_id" ],
525
533
)
526
534
527
- event_ids = [ row ["event_id" ] for row in rows ]
535
+ event_ids = cast ( List [ str ], [ row ["event_id" ] for row in rows ])
528
536
529
537
txn .execute (
530
538
"""
@@ -544,9 +552,9 @@ def _fetch_current_state_stats(txn):
544
552
(room_id ,),
545
553
)
546
554
547
- ( current_state_events_count ,) = txn .fetchone ()
555
+ current_state_events_count = cast ( Tuple [ int ], txn .fetchone ())[ 0 ]
548
556
549
- users_in_room = self .get_users_in_room_txn (txn , room_id )
557
+ users_in_room = self .get_users_in_room_txn (txn , room_id ) # type: ignore[attr-defined]
550
558
551
559
return (
552
560
event_ids ,
@@ -566,7 +574,7 @@ def _fetch_current_state_stats(txn):
566
574
"get_initial_state_for_room" , _fetch_current_state_stats
567
575
)
568
576
569
- state_event_map = await self .get_events (event_ids , get_prev_content = False )
577
+ state_event_map = await self .get_events (event_ids , get_prev_content = False ) # type: ignore[attr-defined]
570
578
571
579
room_state = {
572
580
"join_rules" : None ,
@@ -622,8 +630,10 @@ def _fetch_current_state_stats(txn):
622
630
},
623
631
)
624
632
625
- async def _calculate_and_set_initial_state_for_user (self , user_id ):
626
- def _calculate_and_set_initial_state_for_user_txn (txn ):
633
+ async def _calculate_and_set_initial_state_for_user (self , user_id : str ) -> None :
634
+ def _calculate_and_set_initial_state_for_user_txn (
635
+ txn : LoggingTransaction ,
636
+ ) -> Tuple [int , int ]:
627
637
pos = self ._get_max_stream_id_in_current_state_deltas_txn (txn )
628
638
629
639
txn .execute (
@@ -634,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
634
644
""" ,
635
645
(user_id ,),
636
646
)
637
- ( count ,) = txn .fetchone ()
647
+ count = cast ( Tuple [ int ], txn .fetchone ())[ 0 ]
638
648
return count , pos
639
649
640
650
joined_rooms , pos = await self .db_pool .runInteraction (
@@ -678,7 +688,9 @@ async def get_users_media_usage_paginate(
678
688
users that exist given this query
679
689
"""
680
690
681
- def get_users_media_usage_paginate_txn (txn ):
691
+ def get_users_media_usage_paginate_txn (
692
+ txn : LoggingTransaction ,
693
+ ) -> Tuple [List [JsonDict ], int ]:
682
694
filters = []
683
695
args = [self .hs .config .server .server_name ]
684
696
@@ -733,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
733
745
sql_base = sql_base ,
734
746
)
735
747
txn .execute (sql , args )
736
- count = txn .fetchone ()[0 ]
748
+ count = cast ( Tuple [ int ], txn .fetchone () )[0 ]
737
749
738
750
sql = """
739
751
SELECT
0 commit comments