Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 20117c4

Browse files
committed
Consolidate logic for parsing relations.
1 parent 5c00151 commit 20117c4

File tree

6 files changed

+86
-54
lines changed

6 files changed

+86
-54
lines changed

changelog.d/12693.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Consolidate parsing of relation information from events.

synapse/events/__init__.py

+44
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import abc
18+
import collections.abc
1819
import os
1920
from typing import (
2021
TYPE_CHECKING,
@@ -32,9 +33,11 @@
3233
overload,
3334
)
3435

36+
import attr
3537
from typing_extensions import Literal
3638
from unpaddedbase64 import encode_base64
3739

40+
from synapse.api.constants import RelationTypes
3841
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
3942
from synapse.types import JsonDict, RoomStreamToken
4043
from synapse.util.caches import intern_dict
@@ -287,6 +290,17 @@ def is_historical(self) -> bool:
287290
return self._dict.get("historical", False)
288291

289292

293+
@attr.s(slots=True, frozen=True, auto_attribs=True)
294+
class _EventRelation:
295+
# The target event of the relation.
296+
parent_id: str
297+
# The relation type.
298+
rel_type: str
299+
# The aggregation key. Will be None if the rel_type is not m.annotation or is
300+
# not a string.
301+
aggregation_key: Optional[str]
302+
303+
290304
class EventBase(metaclass=abc.ABCMeta):
291305
@property
292306
@abc.abstractmethod
@@ -415,6 +429,36 @@ def auth_event_ids(self) -> Sequence[str]:
415429
"""
416430
return [e for e, _ in self._dict["auth_events"]]
417431

432+
def relation(self) -> Optional[_EventRelation]:
433+
"""
434+
Parse the event's relation information.
435+
436+
Returns:
437+
The event relation information, if it is valid. None, otherwise.
438+
"""
439+
relation = self.content.get("m.relates_to")
440+
if not relation or not isinstance(relation, collections.abc.Mapping):
441+
# No relation information.
442+
return None
443+
444+
# Relations must have a type and parent event ID.
445+
rel_type = relation.get("rel_type")
446+
if not isinstance(rel_type, str):
447+
return None
448+
449+
parent_id = relation.get("event_id")
450+
if not isinstance(parent_id, str):
451+
return None
452+
453+
# Annotations have a key field.
454+
aggregation_key = None
455+
if rel_type == RelationTypes.ANNOTATION:
456+
aggregation_key = relation.get("key")
457+
if not isinstance(aggregation_key, str):
458+
aggregation_key = None
459+
460+
return _EventRelation(parent_id, rel_type, aggregation_key)
461+
418462
def freeze(self) -> None:
419463
"""'Freeze' the event dict, so it cannot be modified by accident"""
420464

synapse/handlers/message.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -1056,20 +1056,11 @@ async def _validate_event_relation(self, event: EventBase) -> None:
10561056
SynapseError if the event is invalid.
10571057
"""
10581058

1059-
relation = event.content.get("m.relates_to")
1059+
relation = event.relation()
10601060
if not relation:
10611061
return
10621062

1063-
relation_type = relation.get("rel_type")
1064-
if not relation_type:
1065-
return
1066-
1067-
# Ensure the parent is real.
1068-
relates_to = relation.get("event_id")
1069-
if not relates_to:
1070-
return
1071-
1072-
parent_event = await self.store.get_event(relates_to, allow_none=True)
1063+
parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
10731064
if parent_event:
10741065
# And in the same room.
10751066
if parent_event.room_id != event.room_id:
@@ -1078,28 +1069,31 @@ async def _validate_event_relation(self, event: EventBase) -> None:
10781069
else:
10791070
# There must be some reason that the client knows the event exists,
10801071
# see if there are existing relations. If so, assume everything is fine.
1081-
if not await self.store.event_is_target_of_relation(relates_to):
1072+
if not await self.store.event_is_target_of_relation(relation.parent_id):
10821073
# Otherwise, the client can't know about the parent event!
10831074
raise SynapseError(400, "Can't send relation to unknown event")
10841075

10851076
# If this event is an annotation then we check that that the sender
10861077
# can't annotate the same way twice (e.g. stops users from liking an
10871078
# event multiple times).
1088-
if relation_type == RelationTypes.ANNOTATION:
1089-
aggregation_key = relation["key"]
1079+
if relation.rel_type == RelationTypes.ANNOTATION:
1080+
aggregation_key = relation.aggregation_key
1081+
1082+
if aggregation_key is None:
1083+
raise SynapseError(400, "Missing aggregation key")
10901084

10911085
if len(aggregation_key) > 500:
10921086
raise SynapseError(400, "Aggregation key is too long")
10931087

10941088
already_exists = await self.store.has_user_annotated_event(
1095-
relates_to, event.type, aggregation_key, event.sender
1089+
relation.parent_id, event.type, aggregation_key, event.sender
10961090
)
10971091
if already_exists:
10981092
raise SynapseError(400, "Can't send same reaction twice")
10991093

11001094
# Don't attempt to start a thread if the parent event is a relation.
1101-
elif relation_type == RelationTypes.THREAD:
1102-
if await self.store.event_includes_relation(relates_to):
1095+
elif relation.rel_type == RelationTypes.THREAD:
1096+
if await self.store.event_includes_relation(relation.parent_id):
11031097
raise SynapseError(
11041098
400, "Cannot start threads from an event with a relation"
11051099
)

synapse/handlers/relations.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import collections.abc
1514
import logging
1615
from typing import (
1716
TYPE_CHECKING,
@@ -373,20 +372,21 @@ async def get_bundled_aggregations(
373372
if event.is_state():
374373
continue
375374

376-
relates_to = event.content.get("m.relates_to")
377-
relation_type = None
378-
if isinstance(relates_to, collections.abc.Mapping):
379-
relation_type = relates_to.get("rel_type")
375+
relates_to = event.relation()
376+
if relates_to:
380377
# An event which is a replacement (ie edit) or annotation (ie,
381378
# reaction) may not have any other event related to it.
382-
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
379+
if relates_to.rel_type in (
380+
RelationTypes.ANNOTATION,
381+
RelationTypes.REPLACE,
382+
):
383383
continue
384384

385+
# Track the event's relation information for later.
386+
relations_by_id[event.event_id] = relates_to.rel_type
387+
385388
# The event should get bundled aggregations.
386389
events_by_id[event.event_id] = event
387-
# Track the event's relation information for later.
388-
if isinstance(relation_type, str):
389-
relations_by_id[event.event_id] = relation_type
390390

391391
# event ID -> bundled aggregation in non-serialized form.
392392
results: Dict[str, BundledAggregations] = {}

synapse/push/bulk_push_rule_evaluator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
7777
return False
7878

7979
# Exclude edits.
80-
relates_to = event.content.get("m.relates_to", {})
81-
if relates_to.get("rel_type") == RelationTypes.REPLACE:
80+
relates_to = event.relation()
81+
if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
8282
return False
8383

8484
# Mark events that have a non-empty string body as unread.

synapse/storage/databases/main/events.py

+19-26
Original file line numberDiff line numberDiff line change
@@ -1815,52 +1815,45 @@ def _handle_event_relations(
18151815
txn: The current database transaction.
18161816
event: The event which might have relations.
18171817
"""
1818-
relation = event.content.get("m.relates_to")
1818+
relation = event.relation()
18191819
if not relation:
1820-
# No relations
1820+
# No relation, nothing to do.
18211821
return
18221822

1823-
# Relations must have a type and parent event ID.
1824-
rel_type = relation.get("rel_type")
1825-
if not isinstance(rel_type, str):
1826-
return
1827-
1828-
parent_id = relation.get("event_id")
1829-
if not isinstance(parent_id, str):
1830-
return
1831-
1832-
# Annotations have a key field.
1833-
aggregation_key = None
1834-
if rel_type == RelationTypes.ANNOTATION:
1835-
aggregation_key = relation.get("key")
1836-
18371823
self.db_pool.simple_insert_txn(
18381824
txn,
18391825
table="event_relations",
18401826
values={
18411827
"event_id": event.event_id,
1842-
"relates_to_id": parent_id,
1843-
"relation_type": rel_type,
1844-
"aggregation_key": aggregation_key,
1828+
"relates_to_id": relation.parent_id,
1829+
"relation_type": relation.rel_type,
1830+
"aggregation_key": relation.aggregation_key,
18451831
},
18461832
)
18471833

1848-
txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
18491834
txn.call_after(
1850-
self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
1835+
self.store.get_relations_for_event.invalidate, (relation.parent_id,)
1836+
)
1837+
txn.call_after(
1838+
self.store.get_aggregation_groups_for_event.invalidate,
1839+
(relation.parent_id,),
18511840
)
18521841

1853-
if rel_type == RelationTypes.REPLACE:
1854-
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
1842+
if relation.rel_type == RelationTypes.REPLACE:
1843+
txn.call_after(
1844+
self.store.get_applicable_edit.invalidate, (relation.parent_id,)
1845+
)
18551846

1856-
if rel_type == RelationTypes.THREAD:
1857-
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
1847+
if relation.rel_type == RelationTypes.THREAD:
1848+
txn.call_after(
1849+
self.store.get_thread_summary.invalidate, (relation.parent_id,)
1850+
)
18581851
# It should be safe to only invalidate the cache if the user has not
18591852
# previously participated in the thread, but that's difficult (and
18601853
# potentially error-prone) so it is always invalidated.
18611854
txn.call_after(
18621855
self.store.get_thread_participated.invalidate,
1863-
(parent_id, event.sender),
1856+
(relation.parent_id, event.sender),
18641857
)
18651858

18661859
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):

0 commit comments

Comments
 (0)