Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 284ea20

Browse files
reivilibre and clokep authored
Track and deduplicate in-flight requests to _get_state_for_groups. (#10870)
Co-authored-by: Patrick Cloke <[email protected]>
1 parent e6acd3c commit 284ea20

File tree

3 files changed

+312
-25
lines changed

3 files changed

+312
-25
lines changed

changelog.d/10870.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Deduplicate in-flight requests in `_get_state_for_groups`.

synapse/storage/databases/state/store.py

+178-25
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,23 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
16+
from typing import (
17+
TYPE_CHECKING,
18+
Collection,
19+
Dict,
20+
Iterable,
21+
Optional,
22+
Sequence,
23+
Set,
24+
Tuple,
25+
)
1726

1827
import attr
1928

29+
from twisted.internet import defer
30+
2031
from synapse.api.constants import EventTypes
32+
from synapse.logging.context import make_deferred_yieldable, run_in_background
2133
from synapse.storage._base import SQLBaseStore
2234
from synapse.storage.database import (
2335
DatabasePool,
@@ -29,6 +41,12 @@
2941
from synapse.storage.types import Cursor
3042
from synapse.storage.util.sequence import build_sequence_generator
3143
from synapse.types import MutableStateMap, StateKey, StateMap
44+
from synapse.util import unwrapFirstError
45+
from synapse.util.async_helpers import (
46+
AbstractObservableDeferred,
47+
ObservableDeferred,
48+
yieldable_gather_results,
49+
)
3250
from synapse.util.caches.descriptors import cached
3351
from synapse.util.caches.dictionary_cache import DictionaryCache
3452

@@ -37,7 +55,6 @@
3755

3856
logger = logging.getLogger(__name__)
3957

40-
4158
MAX_STATE_DELTA_HOPS = 100
4259

4360

@@ -106,6 +123,12 @@ def __init__(
106123
500000,
107124
)
108125

126+
# Current ongoing get_state_for_groups in-flight requests
127+
# {group ID -> {StateFilter -> ObservableDeferred}}
128+
self._state_group_inflight_requests: Dict[
129+
int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
130+
] = {}
131+
109132
def get_max_state_group_txn(txn: Cursor) -> int:
110133
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
111134
return txn.fetchone()[0] # type: ignore
@@ -157,7 +180,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
157180
)
158181

159182
async def _get_state_groups_from_groups(
160-
self, groups: List[int], state_filter: StateFilter
183+
self, groups: Sequence[int], state_filter: StateFilter
161184
) -> Dict[int, StateMap[str]]:
162185
"""Returns the state groups for a given set of groups from the
163186
database, filtering on types of state events.
@@ -228,6 +251,150 @@ def _get_state_for_group_using_cache(
228251

229252
return state_filter.filter_state(state_dict_ids), not missing_types
230253

254+
def _get_state_for_group_gather_inflight_requests(
    self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
    """
    Work out which already-running requests for `group` can be piggy-backed
    on, and what part of the wanted state filter they do not cover.

    Used as part of _get_state_for_group_using_inflight_cache.

    Args:
        group: ID of the state group whose in-flight requests are inspected
        state_filter_left_over: the state filter we ultimately need satisfied

    Returns:
        Tuple of two values:
            The in-flight ObservableDeferreds worth observing (an empty
            sequence if none help)
            The StateFilter that still has to be requested separately
    """
    in_flight = self._state_group_inflight_requests.get(group)
    if in_flight is None:
        # Nothing is running for this group: the caller must fetch everything.
        return (), state_filter_left_over

    # Requests we decide to share, in the order we encountered them.
    usable = []
    still_needed = state_filter_left_over

    for in_flight_filter, in_flight_deferred in in_flight.items():
        narrowed = still_needed.approx_difference(in_flight_filter)
        if narrowed == still_needed:
            # Subtracting this request's filter changed nothing, so
            # observing it would not reduce what we have to fetch.
            continue

        usable.append(in_flight_deferred)
        still_needed = narrowed

        if still_needed == StateFilter.none():
            # The requests picked so far leave nothing more to fetch,
            # so there is no point looking at the rest.
            break

    return usable, still_needed
292+
293+
async def _get_state_for_group_fire_request(
    self, group: int, state_filter: StateFilter
) -> StateMap[str]:
    """
    Fires off a request to get the state at a state group,
    potentially filtering by type and/or state key.

    While it is running, the request is tracked in
    `_state_group_inflight_requests` (keyed by group and by the expanded
    state filter) so that other callers can observe it instead of issuing
    their own query. The entry is removed when the request finishes,
    whether it succeeds or fails.

    Used as part of _get_state_for_group_using_inflight_cache.

    Args:
        group: ID of the state group for which we want to get state
        state_filter: the state filter used to fetch state from the database

    Returns:
        The state map fetched for `group` using the expanded filter; it is
        not filtered back down to `state_filter` here.

    Raises:
        Whatever the underlying fetch or cache insertion raises.
    """
    # Capture the cache sequence numbers before the fetch starts; they are
    # handed to _insert_into_cache together with the fetched state.
    cache_sequence_nm = self._state_group_cache.sequence
    cache_sequence_m = self._state_group_members_cache.sequence

    # Help the cache hit ratio by expanding the filter a bit
    db_state_filter = state_filter.return_expanded()

    async def _the_request() -> StateMap[str]:
        try:
            group_to_state_dict = await self._get_state_groups_from_groups(
                (group,), state_filter=db_state_filter
            )

            # Now let's update the caches
            self._insert_into_cache(
                group_to_state_dict,
                db_state_filter,
                cache_seq_num_members=cache_sequence_m,
                cache_seq_num_non_members=cache_sequence_nm,
            )

            return group_to_state_dict[group]
        finally:
            # Remove ourselves from the in-flight cache. This must happen
            # on failure too, otherwise a failed request would stay
            # registered forever and be re-used by later callers.
            #
            # The removal is tolerant of the entry being absent: we may
            # have finished before we were registered, or the entry may
            # already have been replaced/removed by another request that
            # used the same expanded filter.
            group_request_dict = self._state_group_inflight_requests.get(group)
            if group_request_dict is not None:
                group_request_dict.pop(db_state_filter, None)
                if not group_request_dict:
                    # If there are no more requests in-flight for this
                    # group, clean up by removing the empty dictionary
                    del self._state_group_inflight_requests[group]

    # We don't immediately await the result, so must use run_in_background
    # But we DO await the result before the current log context (request)
    # finishes, so don't need to run it as a background process.
    request_deferred = run_in_background(_the_request)
    observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)

    # Insert the ObservableDeferred into the cache, but only if the request
    # is genuinely still running. If it has already finished, its cleanup
    # has already run and nothing would ever remove the entry again.
    if not request_deferred.called:
        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
        group_request_dict[db_state_filter] = observable_deferred

    return await make_deferred_yieldable(observable_deferred.observe())
349+
350+
async def _get_state_for_group_using_inflight_cache(
    self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
    """
    Fetch the state at a state group (optionally narrowed by type and/or
    state key), sharing work with requests that are already running.

    The steps are:
      1. Ask _get_state_for_group_gather_inflight_requests which ongoing
         requests can be re-used and which part of the filter they leave
         uncovered.
      2. If anything is left uncovered, fetch it with
         _get_state_for_group_fire_request.
      3. Wait for the re-used requests, merge every piece together and
         trim the result back down to `state_filter`.

    Args:
        group: ID of the state group for which we want to get state
        state_filter: the state filter used to fetch state from the database
    Returns:
        state map
    """
    (
        shared_requests,
        uncovered_filter,
    ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)

    combined_state: MutableStateMap[str] = {}

    if uncovered_filter != StateFilter.none():
        # Some of what we need is not being fetched by anybody else yet,
        # so issue (and wait for) our own request for that part.
        freshly_fetched = await self._get_state_for_group_fire_request(
            group, uncovered_filter
        )
        combined_state.update(freshly_fetched)

    # Wait for every request we are sharing. unwrapFirstError is attached
    # as the errback of the gathered deferred, as is done wherever
    # gatherResults is used with consumeErrors=True.
    observations = defer.gatherResults(
        (shared.observe() for shared in shared_requests), consumeErrors=True
    )
    shared_results = await make_deferred_yieldable(observations).addErrback(
        unwrapFirstError
    )

    # Fold the shared results on top of whatever we fetched ourselves.
    for shared_state in shared_results:
        combined_state.update(shared_state)

    # The pieces may contain more than was asked for; trim the excess.
    return state_filter.filter_state(combined_state)
397+
231398
async def _get_state_for_groups(
232399
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
233400
) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ async def _get_state_for_groups(
269436
if not incomplete_groups:
270437
return state
271438

272-
cache_sequence_nm = self._state_group_cache.sequence
273-
cache_sequence_m = self._state_group_members_cache.sequence
274-
275-
# Help the cache hit ratio by expanding the filter a bit
276-
db_state_filter = state_filter.return_expanded()
277-
278-
group_to_state_dict = await self._get_state_groups_from_groups(
279-
list(incomplete_groups), state_filter=db_state_filter
280-
)
439+
async def get_from_cache(group: int, state_filter: StateFilter) -> None:
440+
state[group] = await self._get_state_for_group_using_inflight_cache(
441+
group, state_filter
442+
)
281443

282-
# Now lets update the caches
283-
self._insert_into_cache(
284-
group_to_state_dict,
285-
db_state_filter,
286-
cache_seq_num_members=cache_sequence_m,
287-
cache_seq_num_non_members=cache_sequence_nm,
444+
await yieldable_gather_results(
445+
get_from_cache,
446+
incomplete_groups,
447+
state_filter,
288448
)
289449

290-
# And finally update the result dict, by filtering out any extra
291-
# stuff we pulled out of the database.
292-
for group, group_state_dict in group_to_state_dict.items():
293-
# We just replace any existing entries, as we will have loaded
294-
# everything we need from the database anyway.
295-
state[group] = state_filter.filter_state(group_state_dict)
296-
297450
return state
298451

299452
def _get_state_for_groups_using_cache(
+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2022 The Matrix.org Foundation C.I.C.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import typing
15+
from typing import Dict, List, Sequence, Tuple
16+
from unittest.mock import patch
17+
18+
from twisted.internet.defer import Deferred, ensureDeferred
19+
from twisted.test.proto_helpers import MemoryReactor
20+
21+
from synapse.storage.state import StateFilter
22+
from synapse.types import MutableStateMap, StateMap
23+
from synapse.util import Clock
24+
25+
from tests.unittest import HomeserverTestCase
26+
27+
if typing.TYPE_CHECKING:
28+
from synapse.server import HomeServer
29+
30+
31+
class StateGroupInflightCachingTestCase(HomeserverTestCase):
    """
    Tests for the in-flight request handling of the state datastore.

    The datastore's `_get_state_groups_from_groups` is replaced by a fake
    that records each call and hands back an unfired Deferred, so that a
    test decides when the "database" answers.
    """

    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
    ) -> None:
        self.state_storage = homeserver.get_storage().state
        self.state_datastore = homeserver.get_datastores().state

        # One entry per intercepted database call:
        # (requested groups, state filter, Deferred to fire with the result)
        self.get_state_group_calls: List[
            Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
        ] = []

        # Swap out `_get_state_groups_from_groups` so we can pretend the
        # database is slow, and undo the swap when the test ends.
        slow_db_patch = patch.object(
            self.state_datastore,
            "_get_state_groups_from_groups",
            self._fake_get_state_groups_from_groups,
        )
        slow_db_patch.start()
        self.addCleanup(slow_db_patch.stop)

    def _fake_get_state_groups_from_groups(
        self, groups: Sequence[int], state_filter: StateFilter
    ) -> "Deferred[Dict[int, StateMap[str]]]":
        """
        Stand-in for the database query: record the call and return a
        Deferred which stays pending until the test fires it.
        """
        pending: Deferred[Dict[int, StateMap[str]]] = Deferred()
        self.get_state_group_calls.append((tuple(groups), state_filter, pending))
        return pending

    def _complete_request_fake(
        self,
        groups: Tuple[int, ...],
        state_filter: StateFilter,
        d: "Deferred[Dict[int, StateMap[str]]]",
    ) -> None:
        """
        Build a made-up database response for a recorded call and fire its
        Deferred with it.

        For every group the response holds:
          - for each type in the filter with explicit state keys, one entry
            per key with value "abc";
          - for each type in the filter whose keys are None, the two keys
            "a" and "b" with value "xyz";
          - one extra entry of another type if the filter includes others.
        """
        result: Dict[int, StateMap[str]] = {}

        for group in groups:
            fake_state: MutableStateMap[str] = {}

            for state_type, state_keys in state_filter.types.items():
                if state_keys is not None:
                    fake_state.update(
                        {(state_type, state_key): "abc" for state_key in state_keys}
                    )
                else:
                    fake_state[(state_type, "a")] = "xyz"
                    fake_state[(state_type, "b")] = "xyz"

            if state_filter.include_others:
                fake_state[("other.event.type", "state.key")] = "123"

            result[group] = fake_state

        d.callback(result)

    def test_duplicate_requests_deduplicated(self) -> None:
        """
        Tests that duplicate requests for state are deduplicated.

        This test:
        - requests some state (state group 42, 'all' state filter)
        - requests it again, before the first request finishes
        - checks to see that only one database query was made
        - completes the database query
        - checks that both requests see the same retrieved state
        """

        def request_all_state_for_group_42() -> "Deferred[MutableStateMap[str]]":
            return ensureDeferred(
                self.state_datastore._get_state_for_group_using_inflight_cache(
                    42, StateFilter.all()
                )
            )

        expected_state = {("other.event.type", "state.key"): "123"}

        req1 = request_all_state_for_group_42()
        self.pump(by=0.1)

        # This should have gone to the database
        self.assertEqual(len(self.get_state_group_calls), 1)
        self.assertFalse(req1.called)

        req2 = request_all_state_for_group_42()
        self.pump(by=0.1)

        # No more calls should have gone to the database
        self.assertEqual(len(self.get_state_group_calls), 1)
        self.assertFalse(req1.called)
        self.assertFalse(req2.called)

        groups, sf, d = self.get_state_group_calls[0]
        self.assertEqual(groups, (42,))
        self.assertEqual(sf, StateFilter.all())

        # Now we can complete the request
        self._complete_request_fake(groups, sf, d)

        self.assertEqual(self.get_success(req1), expected_state)
        self.assertEqual(self.get_success(req2), expected_state)

0 commit comments

Comments
 (0)