|
15 | 15 |
|
16 | 16 | import logging
|
17 | 17 | import re
|
18 |
| -from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union |
| 18 | +from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union |
19 | 19 |
|
20 | 20 | from matrix_common.regex import glob_to_regex, to_word_pattern
|
21 | 21 |
|
@@ -120,11 +120,15 @@ def __init__(
|
120 | 120 | room_member_count: int,
|
121 | 121 | sender_power_level: int,
|
122 | 122 | power_levels: Dict[str, Union[int, Dict[str, int]]],
|
| 123 | + relations: Dict[str, Set[Tuple[str, str]]], |
| 124 | + relations_match_enabled: bool, |
123 | 125 | ):
|
124 | 126 | self._event = event
|
125 | 127 | self._room_member_count = room_member_count
|
126 | 128 | self._sender_power_level = sender_power_level
|
127 | 129 | self._power_levels = power_levels
|
| 130 | + self._relations = relations |
| 131 | + self._relations_match_enabled = relations_match_enabled |
128 | 132 |
|
129 | 133 | # Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
130 | 134 | self._value_cache = _flatten_dict(event)
|
@@ -188,7 +192,16 @@ def matches(
|
188 | 192 | return _sender_notification_permission(
|
189 | 193 | self._event, condition, self._sender_power_level, self._power_levels
|
190 | 194 | )
|
| 195 | + elif ( |
| 196 | + condition["kind"] == "org.matrix.msc3772.relation_match" |
| 197 | + and self._relations_match_enabled |
| 198 | + ): |
| 199 | + return self._relation_match(condition, user_id) |
191 | 200 | else:
|
| 201 | + # XXX This looks incorrect -- we have reached an unknown condition |
| 202 | + # kind and are unconditionally returning that it matches. Note |
| 203 | + # that it seems possible to provide a condition to the /pushrules |
| 204 | + # endpoint with an unknown kind, see _rule_tuple_from_request_object. |
192 | 205 | return True
|
193 | 206 |
|
194 | 207 | def _event_match(self, condition: dict, user_id: str) -> bool:
|
@@ -256,6 +269,41 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool:
|
256 | 269 |
|
257 | 270 | return bool(r.search(body))
|
258 | 271 |
|
| 272 | + def _relation_match(self, condition: dict, user_id: str) -> bool: |
| 273 | + """ |
| 274 | + Check an "relation_match" push rule condition. |
| 275 | +
|
| 276 | + Args: |
| 277 | + condition: The "event_match" push rule condition to match. |
| 278 | + user_id: The user's MXID. |
| 279 | +
|
| 280 | + Returns: |
| 281 | + True if the condition matches the event, False otherwise. |
| 282 | + """ |
| 283 | + rel_type = condition.get("rel_type") |
| 284 | + if not rel_type: |
| 285 | + logger.warning("relation_match condition missing rel_type") |
| 286 | + return False |
| 287 | + |
| 288 | + sender_pattern = condition.get("sender") |
| 289 | + if sender_pattern is None: |
| 290 | + sender_type = condition.get("sender_type") |
| 291 | + if sender_type == "user_id": |
| 292 | + sender_pattern = user_id |
| 293 | + type_pattern = condition.get("type") |
| 294 | + |
| 295 | + # If any other relations matches, return True. |
| 296 | + for sender, event_type in self._relations.get(rel_type, ()): |
| 297 | + if sender_pattern and not _glob_matches(sender_pattern, sender): |
| 298 | + continue |
| 299 | + if type_pattern and not _glob_matches(type_pattern, event_type): |
| 300 | + continue |
| 301 | + # All values must have matched. |
| 302 | + return True |
| 303 | + |
| 304 | + # No relations matched. |
| 305 | + return False |
| 306 | + |
259 | 307 |
|
260 | 308 | # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
261 | 309 | regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
|
|
0 commit comments