Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 71e8afe

Browse files
authored
Update EventContext get_current_event_ids and get_prev_event_ids to accept state filters and update calls where possible (#12791)
1 parent 2be5a2b commit 71e8afe

File tree

10 files changed

+65
-18
lines changed

10 files changed

+65
-18
lines changed

Diff for: changelog.d/12791.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible.

Diff for: synapse/events/snapshot.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
if TYPE_CHECKING:
2525
from synapse.storage import Storage
2626
from synapse.storage.databases.main import DataStore
27+
from synapse.storage.state import StateFilter
2728

2829

2930
@attr.s(slots=True, auto_attribs=True)
@@ -196,14 +197,19 @@ def state_group(self) -> Optional[int]:
196197

197198
return self._state_group
198199

199-
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
200+
async def get_current_state_ids(
201+
self, state_filter: Optional["StateFilter"] = None
202+
) -> Optional[StateMap[str]]:
200203
"""
201204
Gets the room state map, including this event - ie, the state in ``state_group``
202205
203206
It is an error to access this for a rejected event, since rejected state should
204207
not make it into the room state. This method will raise an exception if
205208
``rejected`` is set.
206209
210+
Arg:
211+
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
212+
207213
Returns:
208214
Returns None if state_group is None, which happens when the associated
209215
event is an outlier.
@@ -216,20 +222,25 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
216222

217223
assert self._state_delta_due_to_event is not None
218224

219-
prev_state_ids = await self.get_prev_state_ids()
225+
prev_state_ids = await self.get_prev_state_ids(state_filter)
220226

221227
if self._state_delta_due_to_event:
222228
prev_state_ids = dict(prev_state_ids)
223229
prev_state_ids.update(self._state_delta_due_to_event)
224230

225231
return prev_state_ids
226232

227-
async def get_prev_state_ids(self) -> StateMap[str]:
233+
async def get_prev_state_ids(
234+
self, state_filter: Optional["StateFilter"] = None
235+
) -> StateMap[str]:
228236
"""
229237
Gets the room state map, excluding this event.
230238
231239
For a non-state event, this will be the same as get_current_state_ids().
232240
241+
Args:
242+
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
243+
233244
Returns:
234245
Returns {} if state_group is None, which happens when the associated
235246
event is an outlier.
@@ -239,7 +250,7 @@ async def get_prev_state_ids(self) -> StateMap[str]:
239250
"""
240251
assert self.state_group_before_event is not None
241252
return await self._storage.state.get_state_ids_for_group(
242-
self.state_group_before_event
253+
self.state_group_before_event, state_filter
243254
)
244255

245256

Diff for: synapse/handlers/federation.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
ReplicationStoreRoomOnOutlierMembershipRestServlet,
5555
)
5656
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
57+
from synapse.storage.state import StateFilter
5758
from synapse.types import JsonDict, StateMap, get_domain_from_id
5859
from synapse.util.async_helpers import Linearizer
5960
from synapse.util.retryutils import NotRetryingDestination
@@ -1259,7 +1260,9 @@ async def add_display_name_to_third_party_invite(
12591260
event.content["third_party_invite"]["signed"]["token"],
12601261
)
12611262
original_invite = None
1262-
prev_state_ids = await context.get_prev_state_ids()
1263+
prev_state_ids = await context.get_prev_state_ids(
1264+
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
1265+
)
12631266
original_invite_id = prev_state_ids.get(key)
12641267
if original_invite_id:
12651268
original_invite = await self.store.get_event(
@@ -1308,7 +1311,9 @@ async def _check_signature(self, event: EventBase, context: EventContext) -> Non
13081311
signed = event.content["third_party_invite"]["signed"]
13091312
token = signed["token"]
13101313

1311-
prev_state_ids = await context.get_prev_state_ids()
1314+
prev_state_ids = await context.get_prev_state_ids(
1315+
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
1316+
)
13121317
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
13131318

13141319
invite_event = None

Diff for: synapse/handlers/federation_event.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from prometheus_client import Counter
3232

33+
from synapse import event_auth
3334
from synapse.api.constants import (
3435
EventContentFields,
3536
EventTypes,
@@ -63,6 +64,7 @@
6364
)
6465
from synapse.state import StateResolutionStore
6566
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
67+
from synapse.storage.state import StateFilter
6668
from synapse.types import (
6769
PersistedEventPosition,
6870
RoomStreamToken,
@@ -1500,7 +1502,11 @@ async def _check_event_auth(
15001502
return context
15011503

15021504
# now check auth against what we think the auth events *should* be.
1503-
prev_state_ids = await context.get_prev_state_ids()
1505+
event_types = event_auth.auth_types_for_event(event.room_version, event)
1506+
prev_state_ids = await context.get_prev_state_ids(
1507+
StateFilter.from_types(event_types)
1508+
)
1509+
15041510
auth_events_ids = self._event_auth_handler.compute_auth_events(
15051511
event, prev_state_ids, for_verification=True
15061512
)

Diff for: synapse/handlers/message.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,9 @@ async def create_event(
634634
# federation as well as those created locally. As of room v3, aliases events
635635
# can be created by users that are not in the room, therefore we have to
636636
# tolerate them in event_auth.check().
637-
prev_state_ids = await context.get_prev_state_ids()
637+
prev_state_ids = await context.get_prev_state_ids(
638+
StateFilter.from_types([(EventTypes.Member, None)])
639+
)
638640
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
639641
prev_event = (
640642
await self.store.get_event(prev_event_id, allow_none=True)
@@ -761,7 +763,9 @@ async def deduplicate_state_event(
761763
# This can happen due to out of band memberships
762764
return None
763765

764-
prev_state_ids = await context.get_prev_state_ids()
766+
prev_state_ids = await context.get_prev_state_ids(
767+
StateFilter.from_types([(event.type, None)])
768+
)
765769
prev_event_id = prev_state_ids.get((event.type, event.state_key))
766770
if not prev_event_id:
767771
return None
@@ -1547,7 +1551,11 @@ async def persist_and_notify_client_event(
15471551
"Redacting MSC2716 events is not supported in this room version",
15481552
)
15491553

1550-
prev_state_ids = await context.get_prev_state_ids()
1554+
event_types = event_auth.auth_types_for_event(event.room_version, event)
1555+
prev_state_ids = await context.get_prev_state_ids(
1556+
StateFilter.from_types(event_types)
1557+
)
1558+
15511559
auth_events_ids = self._event_auth_handler.compute_auth_events(
15521560
event, prev_state_ids, for_verification=True
15531561
)

Diff for: synapse/handlers/room.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,10 @@ async def _upgrade_room(
303303
context=tombstone_context,
304304
)
305305

306-
old_room_state = await tombstone_context.get_current_state_ids()
306+
state_filter = StateFilter.from_types(
307+
[(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
308+
)
309+
old_room_state = await tombstone_context.get_current_state_ids(state_filter)
307310

308311
# We know the tombstone event isn't an outlier so it has current state.
309312
assert old_room_state is not None

Diff for: synapse/handlers/room_member.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from synapse.events import EventBase
3939
from synapse.events.snapshot import EventContext
4040
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
41+
from synapse.storage.state import StateFilter
4142
from synapse.types import (
4243
JsonDict,
4344
Requester,
@@ -362,7 +363,9 @@ async def _local_membership_update(
362363
historical=historical,
363364
)
364365

365-
prev_state_ids = await context.get_prev_state_ids()
366+
prev_state_ids = await context.get_prev_state_ids(
367+
StateFilter.from_types([(EventTypes.Member, None)])
368+
)
366369

367370
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
368371

@@ -1160,7 +1163,9 @@ async def send_membership_event(
11601163
else:
11611164
requester = types.create_requester(target_user)
11621165

1163-
prev_state_ids = await context.get_prev_state_ids()
1166+
prev_state_ids = await context.get_prev_state_ids(
1167+
StateFilter.from_types([(EventTypes.GuestAccess, None)])
1168+
)
11641169
if event.membership == Membership.JOIN:
11651170
if requester.is_guest:
11661171
guest_can_join = await self._can_guest_join(prev_state_ids)

Diff for: synapse/push/bulk_push_rule_evaluator.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from prometheus_client import Counter
2121

2222
from synapse.api.constants import EventTypes, Membership, RelationTypes
23-
from synapse.event_auth import get_user_power_level
23+
from synapse.event_auth import auth_types_for_event, get_user_power_level
2424
from synapse.events import EventBase, relation_from_event
2525
from synapse.events.snapshot import EventContext
2626
from synapse.state import POWER_KEY
@@ -31,6 +31,7 @@
3131
from synapse.util.caches.lrucache import LruCache
3232
from synapse.util.metrics import measure_func
3333

34+
from ..storage.state import StateFilter
3435
from .push_rule_evaluator import PushRuleEvaluatorForEvent
3536

3637
if TYPE_CHECKING:
@@ -168,8 +169,12 @@ def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
168169
async def _get_power_levels_and_sender_level(
169170
self, event: EventBase, context: EventContext
170171
) -> Tuple[dict, int]:
171-
prev_state_ids = await context.get_prev_state_ids()
172+
event_types = auth_types_for_event(event.room_version, event)
173+
prev_state_ids = await context.get_prev_state_ids(
174+
StateFilter.from_types(event_types)
175+
)
172176
pl_event_id = prev_state_ids.get(POWER_KEY)
177+
173178
if pl_event_id:
174179
# fastpath: if there's a power level event, that's all we need, and
175180
# not having a power level event is an extreme edge case

Diff for: synapse/storage/state.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -634,16 +634,19 @@ async def get_state_groups_ids(
634634

635635
return group_to_state
636636

637-
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
637+
async def get_state_ids_for_group(
638+
self, state_group: int, state_filter: Optional[StateFilter] = None
639+
) -> StateMap[str]:
638640
"""Get the event IDs of all the state in the given state group
639641
640642
Args:
641643
state_group: A state group for which we want to get the state IDs.
644+
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
642645
643646
Returns:
644647
Resolves to a map of (type, state_key) -> event_id
645648
"""
646-
group_to_state = await self.get_state_for_groups((state_group,))
649+
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
647650

648651
return group_to_state[state_group]
649652

Diff for: tests/test_state.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def get_state_groups_ids(self, room_id, event_ids):
8888

8989
return groups
9090

91-
async def get_state_ids_for_group(self, state_group):
91+
async def get_state_ids_for_group(self, state_group, state_filter=None):
9292
return self._group_to_state[state_group]
9393

9494
async def store_state_group(

0 commit comments

Comments
 (0)