14
14
# limitations under the License.
15
15
16
16
import logging
17
- from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Tuple
17
+ from typing import (
18
+ TYPE_CHECKING ,
19
+ Any ,
20
+ Collection ,
21
+ Dict ,
22
+ Iterable ,
23
+ List ,
24
+ Optional ,
25
+ Set ,
26
+ Tuple ,
27
+ )
18
28
19
29
from twisted .internet import defer
20
30
31
+ from synapse .api .constants import ReceiptTypes
21
32
from synapse .replication .slave .storage ._slaved_id_tracker import SlavedIdTracker
22
33
from synapse .replication .tcp .streams import ReceiptsStream
23
34
from synapse .storage ._base import SQLBaseStore , db_to_json , make_in_list_sql_clause
24
- from synapse .storage .database import DatabasePool
35
+ from synapse .storage .database import DatabasePool , LoggingTransaction
25
36
from synapse .storage .engines import PostgresEngine
26
37
from synapse .storage .util .id_generators import MultiWriterIdGenerator , StreamIdGenerator
27
38
from synapse .types import JsonDict
@@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
78
89
"ReceiptsRoomChangeCache" , self .get_max_receipt_stream_id ()
79
90
)
80
91
81
- def get_max_receipt_stream_id (self ):
82
- """Get the current max stream ID for receipts stream
83
-
84
- Returns:
85
- int
86
- """
92
+ def get_max_receipt_stream_id (self ) -> int :
93
+ """Get the current max stream ID for receipts stream"""
87
94
return self ._receipts_id_gen .get_current_token ()
88
95
89
96
@cached ()
90
- async def get_users_with_read_receipts_in_room (self , room_id ) :
91
- receipts = await self .get_receipts_for_room (room_id , "m.read" )
97
+ async def get_users_with_read_receipts_in_room (self , room_id : str ) -> Set [ str ] :
98
+ receipts = await self .get_receipts_for_room (room_id , ReceiptTypes . READ )
92
99
return {r ["user_id" ] for r in receipts }
93
100
94
101
@cached (num_args = 2 )
@@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user(
119
126
)
120
127
121
128
@cached (num_args = 2 )
122
- async def get_receipts_for_user (self , user_id , receipt_type ):
129
+ async def get_receipts_for_user (
130
+ self , user_id : str , receipt_type : str
131
+ ) -> Dict [str , str ]:
123
132
rows = await self .db_pool .simple_select_list (
124
133
table = "receipts_linearized" ,
125
134
keyvalues = {"user_id" : user_id , "receipt_type" : receipt_type },
@@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):
129
138
130
139
return {row ["room_id" ]: row ["event_id" ] for row in rows }
131
140
132
- async def get_receipts_for_user_with_orderings (self , user_id , receipt_type ):
133
- def f (txn ):
141
+ async def get_receipts_for_user_with_orderings (
142
+ self , user_id : str , receipt_type : str
143
+ ) -> JsonDict :
144
+ def f (txn : LoggingTransaction ) -> List [Tuple [str , str , int , int ]]:
134
145
sql = (
135
146
"SELECT rl.room_id, rl.event_id,"
136
147
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room(
209
220
@cached (num_args = 3 , tree = True )
210
221
async def _get_linearized_receipts_for_room (
211
222
self , room_id : str , to_key : int , from_key : Optional [int ] = None
212
- ) -> List [dict ]:
223
+ ) -> List [JsonDict ]:
213
224
"""See get_linearized_receipts_for_room"""
214
225
215
- def f (txn ) :
226
+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
216
227
if from_key :
217
228
sql = (
218
229
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ def f(txn):
250
261
list_name = "room_ids" ,
251
262
num_args = 3 ,
252
263
)
253
- async def _get_linearized_receipts_for_rooms (self , room_ids , to_key , from_key = None ):
264
+ async def _get_linearized_receipts_for_rooms (
265
+ self , room_ids : Collection [str ], to_key : int , from_key : Optional [int ] = None
266
+ ) -> Dict [str , List [JsonDict ]]:
254
267
if not room_ids :
255
268
return {}
256
269
257
- def f (txn ) :
270
+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
258
271
if from_key :
259
272
sql = """
260
273
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms(
323
336
A dictionary of roomids to a list of receipts.
324
337
"""
325
338
326
- def f (txn ) :
339
+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
327
340
if from_key :
328
341
sql = """
329
342
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ async def get_users_sent_receipts_between(
379
392
if last_id == current_id :
380
393
return defer .succeed ([])
381
394
382
- def _get_users_sent_receipts_between_txn (txn ) :
395
+ def _get_users_sent_receipts_between_txn (txn : LoggingTransaction ) -> List [ str ] :
383
396
sql = """
384
397
SELECT DISTINCT user_id FROM receipts_linearized
385
398
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ async def get_all_updated_receipts(
419
432
if last_id == current_id :
420
433
return [], current_id , False
421
434
422
- def get_all_updated_receipts_txn (txn ):
435
+ def get_all_updated_receipts_txn (
436
+ txn : LoggingTransaction ,
437
+ ) -> Tuple [List [Tuple [int , list ]], int , bool ]:
423
438
sql = """
424
439
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
425
440
FROM receipts_linearized
@@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn):
446
461
447
462
def _invalidate_get_users_with_receipts_in_room (
448
463
self , room_id : str , receipt_type : str , user_id : str
449
- ):
450
- if receipt_type != "m.read" :
464
+ ) -> None :
465
+ if receipt_type != ReceiptTypes . READ :
451
466
return
452
467
453
468
res = self .get_users_with_read_receipts_in_room .cache .get_immediate (
@@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room(
461
476
462
477
self .get_users_with_read_receipts_in_room .invalidate ((room_id ,))
463
478
464
- def invalidate_caches_for_receipt (self , room_id , receipt_type , user_id ):
479
+ def invalidate_caches_for_receipt (
480
+ self , room_id : str , receipt_type : str , user_id : str
481
+ ) -> None :
465
482
self .get_receipts_for_user .invalidate ((user_id , receipt_type ))
466
483
self ._get_linearized_receipts_for_room .invalidate ((room_id ,))
467
484
self .get_last_receipt_event_id_for_user .invalidate (
@@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
482
499
return super ().process_replication_rows (stream_name , instance_name , token , rows )
483
500
484
501
def insert_linearized_receipt_txn (
485
- self , txn , room_id , receipt_type , user_id , event_id , data , stream_id
486
- ):
502
+ self ,
503
+ txn : LoggingTransaction ,
504
+ room_id : str ,
505
+ receipt_type : str ,
506
+ user_id : str ,
507
+ event_id : str ,
508
+ data : JsonDict ,
509
+ stream_id : int ,
510
+ ) -> Optional [int ]:
487
511
"""Inserts a read-receipt into the database if it's newer than the current RR
488
512
489
- Returns: int|None
513
+ Returns:
490
514
None if the RR is older than the current RR
491
515
otherwise, the rx timestamp of the event that the RR corresponds to
492
516
(or 0 if the event is unknown)
@@ -550,7 +574,7 @@ def insert_linearized_receipt_txn(
550
574
lock = False ,
551
575
)
552
576
553
- if receipt_type == "m.read" and stream_ordering is not None :
577
+ if receipt_type == ReceiptTypes . READ and stream_ordering is not None :
554
578
self ._remove_old_push_actions_before_txn (
555
579
txn , room_id = room_id , user_id = user_id , stream_ordering = stream_ordering
556
580
)
@@ -580,7 +604,7 @@ async def insert_receipt(
580
604
else :
581
605
# we need to points in graph -> linearized form.
582
606
# TODO: Make this better.
583
- def graph_to_linear (txn ) :
607
+ def graph_to_linear (txn : LoggingTransaction ) -> str :
584
608
clause , args = make_in_list_sql_clause (
585
609
self .database_engine , "event_id" , event_ids
586
610
)
@@ -634,11 +658,16 @@ def graph_to_linear(txn):
634
658
return stream_id , max_persisted_id
635
659
636
660
async def insert_graph_receipt (
637
- self , room_id , receipt_type , user_id , event_ids , data
638
- ):
661
+ self ,
662
+ room_id : str ,
663
+ receipt_type : str ,
664
+ user_id : str ,
665
+ event_ids : List [str ],
666
+ data : JsonDict ,
667
+ ) -> None :
639
668
assert self ._can_write_to_receipts
640
669
641
- return await self .db_pool .runInteraction (
670
+ await self .db_pool .runInteraction (
642
671
"insert_graph_receipt" ,
643
672
self .insert_graph_receipt_txn ,
644
673
room_id ,
@@ -649,8 +678,14 @@ async def insert_graph_receipt(
649
678
)
650
679
651
680
def insert_graph_receipt_txn (
652
- self , txn , room_id , receipt_type , user_id , event_ids , data
653
- ):
681
+ self ,
682
+ txn : LoggingTransaction ,
683
+ room_id : str ,
684
+ receipt_type : str ,
685
+ user_id : str ,
686
+ event_ids : List [str ],
687
+ data : JsonDict ,
688
+ ) -> None :
654
689
assert self ._can_write_to_receipts
655
690
656
691
txn .call_after (self .get_receipts_for_room .invalidate , (room_id , receipt_type ))
0 commit comments