Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit d93362d

Browse files
authored
Add a constant for receipt types (m.read). (#11531)
And expand some type hints in the receipts storage module.
1 parent 7ecaa3b commit d93362d

File tree

9 files changed

+87
-45
lines changed

9 files changed

+87
-45
lines changed

changelog.d/11531.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a receipt types constant for `m.read`.

synapse/api/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,9 @@ class GuestAccess:
253253
FORBIDDEN: Final = "forbidden"
254254

255255

256+
class ReceiptTypes:
257+
READ: Final = "m.read"
258+
259+
256260
class ReadReceiptEventFields:
257261
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"

synapse/handlers/receipts.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
1616

17-
from synapse.api.constants import ReadReceiptEventFields
17+
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
1818
from synapse.appservice import ApplicationService
1919
from synapse.streams import EventSource
2020
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
178178

179179
for event_id in content.keys():
180180
event_content = content.get(event_id, {})
181-
m_read = event_content.get("m.read", {})
181+
m_read = event_content.get(ReceiptTypes.READ, {})
182182

183183
# If m_read is missing copy over the original event_content as there is nothing to process here
184184
if not m_read:
@@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
206206

207207
# Set new users unless empty
208208
if len(new_users.keys()) > 0:
209-
new_event["content"][event_id] = {"m.read": new_users}
209+
new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
210210

211211
# Append new_event to visible_events unless empty
212212
if len(new_event["content"].keys()) > 0:

synapse/handlers/sync.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import attr
2929
from prometheus_client import Counter
3030

31-
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
31+
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
3232
from synapse.api.filtering import FilterCollection
3333
from synapse.api.presence import UserPresenceState
3434
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1046,7 +1046,7 @@ async def unread_notifs_for_room_id(
10461046
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
10471047
user_id=sync_config.user.to_string(),
10481048
room_id=room_id,
1049-
receipt_type="m.read",
1049+
receipt_type=ReceiptTypes.READ,
10501050
)
10511051

10521052
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(

synapse/push/push_tools.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from typing import Dict
1515

16+
from synapse.api.constants import ReceiptTypes
1617
from synapse.events import EventBase
1718
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
1819
from synapse.storage import Storage
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
2324
invites = await store.get_invited_rooms_for_local_user(user_id)
2425
joins = await store.get_rooms_for_user(user_id)
2526

26-
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
27+
my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
2728

2829
badge = len(invites)
2930

synapse/rest/client/notifications.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
from typing import TYPE_CHECKING, Tuple
1717

18+
from synapse.api.constants import ReceiptTypes
1819
from synapse.events.utils import format_event_for_client_v2_without_room_id
1920
from synapse.http.server import HttpServer
2021
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,7 +55,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
5455
)
5556

5657
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
57-
user_id, "m.read"
58+
user_id, ReceiptTypes.READ
5859
)
5960

6061
notif_event_ids = [pa["event_id"] for pa in push_actions]

synapse/rest/client/read_marker.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import logging
1616
from typing import TYPE_CHECKING, Tuple
1717

18-
from synapse.api.constants import ReadReceiptEventFields
18+
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
1919
from synapse.api.errors import Codes, SynapseError
2020
from synapse.http.server import HttpServer
2121
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ async def on_POST(
4848
await self.presence_handler.bump_presence_active_time(requester.user)
4949

5050
body = parse_json_object_from_request(request)
51-
read_event_id = body.get("m.read", None)
51+
read_event_id = body.get(ReceiptTypes.READ, None)
5252
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
5353

5454
if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ async def on_POST(
6262
if read_event_id:
6363
await self.receipts_handler.received_client_receipt(
6464
room_id,
65-
"m.read",
65+
ReceiptTypes.READ,
6666
user_id=requester.user.to_string(),
6767
event_id=read_event_id,
6868
hidden=hidden,

synapse/rest/client/receipts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import re
1717
from typing import TYPE_CHECKING, Tuple
1818

19-
from synapse.api.constants import ReadReceiptEventFields
19+
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
2020
from synapse.api.errors import Codes, SynapseError
2121
from synapse.http import get_request_user_agent
2222
from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ async def on_POST(
5353
) -> Tuple[int, JsonDict]:
5454
requester = await self.auth.get_user_by_req(request)
5555

56-
if receipt_type != "m.read":
56+
if receipt_type != ReceiptTypes.READ:
5757
raise SynapseError(400, "Receipt type must be 'm.read'")
5858

5959
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.

synapse/storage/databases/main/receipts.py

+68-33
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,25 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
17+
from typing import (
18+
TYPE_CHECKING,
19+
Any,
20+
Collection,
21+
Dict,
22+
Iterable,
23+
List,
24+
Optional,
25+
Set,
26+
Tuple,
27+
)
1828

1929
from twisted.internet import defer
2030

31+
from synapse.api.constants import ReceiptTypes
2132
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
2233
from synapse.replication.tcp.streams import ReceiptsStream
2334
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
24-
from synapse.storage.database import DatabasePool
35+
from synapse.storage.database import DatabasePool, LoggingTransaction
2536
from synapse.storage.engines import PostgresEngine
2637
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
2738
from synapse.types import JsonDict
@@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
7889
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
7990
)
8091

81-
def get_max_receipt_stream_id(self):
82-
"""Get the current max stream ID for receipts stream
83-
84-
Returns:
85-
int
86-
"""
92+
def get_max_receipt_stream_id(self) -> int:
93+
"""Get the current max stream ID for receipts stream"""
8794
return self._receipts_id_gen.get_current_token()
8895

8996
@cached()
90-
async def get_users_with_read_receipts_in_room(self, room_id):
91-
receipts = await self.get_receipts_for_room(room_id, "m.read")
97+
async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
98+
receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
9299
return {r["user_id"] for r in receipts}
93100

94101
@cached(num_args=2)
@@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user(
119126
)
120127

121128
@cached(num_args=2)
122-
async def get_receipts_for_user(self, user_id, receipt_type):
129+
async def get_receipts_for_user(
130+
self, user_id: str, receipt_type: str
131+
) -> Dict[str, str]:
123132
rows = await self.db_pool.simple_select_list(
124133
table="receipts_linearized",
125134
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):
129138

130139
return {row["room_id"]: row["event_id"] for row in rows}
131140

132-
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
133-
def f(txn):
141+
async def get_receipts_for_user_with_orderings(
142+
self, user_id: str, receipt_type: str
143+
) -> JsonDict:
144+
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
134145
sql = (
135146
"SELECT rl.room_id, rl.event_id,"
136147
" e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room(
209220
@cached(num_args=3, tree=True)
210221
async def _get_linearized_receipts_for_room(
211222
self, room_id: str, to_key: int, from_key: Optional[int] = None
212-
) -> List[dict]:
223+
) -> List[JsonDict]:
213224
"""See get_linearized_receipts_for_room"""
214225

215-
def f(txn):
226+
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
216227
if from_key:
217228
sql = (
218229
"SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ def f(txn):
250261
list_name="room_ids",
251262
num_args=3,
252263
)
253-
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
264+
async def _get_linearized_receipts_for_rooms(
265+
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
266+
) -> Dict[str, List[JsonDict]]:
254267
if not room_ids:
255268
return {}
256269

257-
def f(txn):
270+
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
258271
if from_key:
259272
sql = """
260273
SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms(
323336
A dictionary of roomids to a list of receipts.
324337
"""
325338

326-
def f(txn):
339+
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
327340
if from_key:
328341
sql = """
329342
SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ async def get_users_sent_receipts_between(
379392
if last_id == current_id:
380393
return defer.succeed([])
381394

382-
def _get_users_sent_receipts_between_txn(txn):
395+
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
383396
sql = """
384397
SELECT DISTINCT user_id FROM receipts_linearized
385398
WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ async def get_all_updated_receipts(
419432
if last_id == current_id:
420433
return [], current_id, False
421434

422-
def get_all_updated_receipts_txn(txn):
435+
def get_all_updated_receipts_txn(
436+
txn: LoggingTransaction,
437+
) -> Tuple[List[Tuple[int, list]], int, bool]:
423438
sql = """
424439
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
425440
FROM receipts_linearized
@@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn):
446461

447462
def _invalidate_get_users_with_receipts_in_room(
448463
self, room_id: str, receipt_type: str, user_id: str
449-
):
450-
if receipt_type != "m.read":
464+
) -> None:
465+
if receipt_type != ReceiptTypes.READ:
451466
return
452467

453468
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room(
461476

462477
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
463478

464-
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
479+
def invalidate_caches_for_receipt(
480+
self, room_id: str, receipt_type: str, user_id: str
481+
) -> None:
465482
self.get_receipts_for_user.invalidate((user_id, receipt_type))
466483
self._get_linearized_receipts_for_room.invalidate((room_id,))
467484
self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
482499
return super().process_replication_rows(stream_name, instance_name, token, rows)
483500

484501
def insert_linearized_receipt_txn(
485-
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
486-
):
502+
self,
503+
txn: LoggingTransaction,
504+
room_id: str,
505+
receipt_type: str,
506+
user_id: str,
507+
event_id: str,
508+
data: JsonDict,
509+
stream_id: int,
510+
) -> Optional[int]:
487511
"""Inserts a read-receipt into the database if it's newer than the current RR
488512
489-
Returns: int|None
513+
Returns:
490514
None if the RR is older than the current RR
491515
otherwise, the rx timestamp of the event that the RR corresponds to
492516
(or 0 if the event is unknown)
@@ -550,7 +574,7 @@ def insert_linearized_receipt_txn(
550574
lock=False,
551575
)
552576

553-
if receipt_type == "m.read" and stream_ordering is not None:
577+
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
554578
self._remove_old_push_actions_before_txn(
555579
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
556580
)
@@ -580,7 +604,7 @@ async def insert_receipt(
580604
else:
581605
# we need to points in graph -> linearized form.
582606
# TODO: Make this better.
583-
def graph_to_linear(txn):
607+
def graph_to_linear(txn: LoggingTransaction) -> str:
584608
clause, args = make_in_list_sql_clause(
585609
self.database_engine, "event_id", event_ids
586610
)
@@ -634,11 +658,16 @@ def graph_to_linear(txn):
634658
return stream_id, max_persisted_id
635659

636660
async def insert_graph_receipt(
637-
self, room_id, receipt_type, user_id, event_ids, data
638-
):
661+
self,
662+
room_id: str,
663+
receipt_type: str,
664+
user_id: str,
665+
event_ids: List[str],
666+
data: JsonDict,
667+
) -> None:
639668
assert self._can_write_to_receipts
640669

641-
return await self.db_pool.runInteraction(
670+
await self.db_pool.runInteraction(
642671
"insert_graph_receipt",
643672
self.insert_graph_receipt_txn,
644673
room_id,
@@ -649,8 +678,14 @@ async def insert_graph_receipt(
649678
)
650679

651680
def insert_graph_receipt_txn(
652-
self, txn, room_id, receipt_type, user_id, event_ids, data
653-
):
681+
self,
682+
txn: LoggingTransaction,
683+
room_id: str,
684+
receipt_type: str,
685+
user_id: str,
686+
event_ids: List[str],
687+
data: JsonDict,
688+
) -> None:
654689
assert self._can_write_to_receipts
655690

656691
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))

0 commit comments

Comments
 (0)