|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import logging
|
16 |
| -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple |
| 16 | +from typing import ( |
| 17 | + TYPE_CHECKING, |
| 18 | + Collection, |
| 19 | + Dict, |
| 20 | + Iterable, |
| 21 | + Optional, |
| 22 | + Sequence, |
| 23 | + Set, |
| 24 | + Tuple, |
| 25 | +) |
17 | 26 |
|
18 | 27 | import attr
|
19 | 28 |
|
| 29 | +from twisted.internet import defer |
| 30 | + |
20 | 31 | from synapse.api.constants import EventTypes
|
| 32 | +from synapse.logging.context import make_deferred_yieldable, run_in_background |
21 | 33 | from synapse.storage._base import SQLBaseStore
|
22 | 34 | from synapse.storage.database import (
|
23 | 35 | DatabasePool,
|
|
29 | 41 | from synapse.storage.types import Cursor
|
30 | 42 | from synapse.storage.util.sequence import build_sequence_generator
|
31 | 43 | from synapse.types import MutableStateMap, StateKey, StateMap
|
| 44 | +from synapse.util import unwrapFirstError |
| 45 | +from synapse.util.async_helpers import ( |
| 46 | + AbstractObservableDeferred, |
| 47 | + ObservableDeferred, |
| 48 | + yieldable_gather_results, |
| 49 | +) |
32 | 50 | from synapse.util.caches.descriptors import cached
|
33 | 51 | from synapse.util.caches.dictionary_cache import DictionaryCache
|
34 | 52 |
|
|
37 | 55 |
|
38 | 56 | logger = logging.getLogger(__name__)
|
39 | 57 |
|
40 |
| - |
41 | 58 | MAX_STATE_DELTA_HOPS = 100
|
42 | 59 |
|
43 | 60 |
|
@@ -106,6 +123,12 @@ def __init__(
|
106 | 123 | 500000,
|
107 | 124 | )
|
108 | 125 |
|
| 126 | + # Current ongoing get_state_for_groups in-flight requests |
| 127 | + # {group ID -> {StateFilter -> ObservableDeferred}} |
| 128 | + self._state_group_inflight_requests: Dict[ |
| 129 | + int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]] |
| 130 | + ] = {} |
| 131 | + |
109 | 132 | def get_max_state_group_txn(txn: Cursor) -> int:
|
110 | 133 | txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
111 | 134 | return txn.fetchone()[0] # type: ignore
|
@@ -157,7 +180,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
|
157 | 180 | )
|
158 | 181 |
|
159 | 182 | async def _get_state_groups_from_groups(
|
160 |
| - self, groups: List[int], state_filter: StateFilter |
| 183 | + self, groups: Sequence[int], state_filter: StateFilter |
161 | 184 | ) -> Dict[int, StateMap[str]]:
|
162 | 185 | """Returns the state groups for a given set of groups from the
|
163 | 186 | database, filtering on types of state events.
|
@@ -228,6 +251,150 @@ def _get_state_for_group_using_cache(
|
228 | 251 |
|
229 | 252 | return state_filter.filter_state(state_dict_ids), not missing_types
|
230 | 253 |
|
def _get_state_for_group_gather_inflight_requests(
    self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
    """
    Work out which already-running requests for `group` can be piggy-backed
    on to satisfy (part of) the given state filter.

    Used as part of _get_state_for_group_using_inflight_cache.

    Args:
        group: ID of the state group whose in-flight requests are inspected
        state_filter_left_over: the state filter we are trying to satisfy

    Returns:
        Tuple of two values:
            A sequence of ObservableDeferreds to observe
            A StateFilter describing what is still not covered by those
            requests and therefore has to be requested separately
    """
    pending = self._state_group_inflight_requests.get(group)
    if pending is None:
        # Nothing is in flight for this group: the caller has to fetch
        # everything it asked for.
        return (), state_filter_left_over

    remaining = state_filter_left_over
    # In-flight requests that each shrink what we still have to fetch.
    chosen = []
    for in_flight_filter, in_flight_deferred in pending.items():
        narrowed = remaining.approx_difference(in_flight_filter)
        if not (narrowed == remaining):
            # This request covers some of what we still need, so observe it.
            chosen.append(in_flight_deferred)
            remaining = narrowed
            if remaining == StateFilter.none():
                # The requests picked so far cover the whole filter;
                # no need to look at any others.
                break
        # Otherwise observing this request would gain us nothing: skip it.

    return chosen, remaining
| 292 | + |
async def _get_state_for_group_fire_request(
    self, group: int, state_filter: StateFilter
) -> StateMap[str]:
    """
    Fires off a request to get the state at a state group,
    potentially filtering by type and/or state key.

    This request will be tracked in the in-flight request cache and automatically
    removed when it is finished, whether it succeeded or failed.

    Used as part of _get_state_for_group_using_inflight_cache.

    Args:
        group: ID of the state group for which we want to get state
        state_filter: the state filter used to fetch state from the database

    Returns:
        The state map fetched for `group`. The database is queried with the
        expanded form of `state_filter`, so the result may contain more state
        than was asked for; callers are expected to filter it.

    Raises:
        Whatever the database fetch or the cache insertion raises.
    """
    # Snapshot the cache sequence numbers before the fetch starts; they are
    # handed to _insert_into_cache together with the fetched state.
    # NOTE(review): presumably this lets the caches discard the insert if they
    # were invalidated while the fetch was running -- confirm.
    cache_sequence_nm = self._state_group_cache.sequence
    cache_sequence_m = self._state_group_members_cache.sequence

    # Help the cache hit ratio by expanding the filter a bit
    db_state_filter = state_filter.return_expanded()

    async def _the_request() -> StateMap[str]:
        try:
            group_to_state_dict = await self._get_state_groups_from_groups(
                (group,), state_filter=db_state_filter
            )

            # Now let's update the caches
            self._insert_into_cache(
                group_to_state_dict,
                db_state_filter,
                cache_seq_num_members=cache_sequence_m,
                cache_seq_num_non_members=cache_sequence_nm,
            )

            return group_to_state_dict[group]
        finally:
            # Remove ourselves from the in-flight cache. This must also happen
            # on failure, otherwise the failed request would stay registered
            # and every later overlapping request would re-observe the error.
            #
            # Be tolerant of the entry being absent: it may not have been
            # registered (if we finished before registration) or may already
            # have been removed by another request using the same filter.
            group_request_dict = self._state_group_inflight_requests.get(group)
            if group_request_dict is not None:
                group_request_dict.pop(db_state_filter, None)
                if not group_request_dict:
                    # If there are no more requests in-flight for this group,
                    # clean up the cache by removing the empty dictionary
                    del self._state_group_inflight_requests[group]

    # We don't immediately await the result, so must use run_in_background
    # But we DO await the result before the current log context (request)
    # finishes, so don't need to run it as a background process.
    request_deferred = run_in_background(_the_request)
    observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)

    # Insert the ObservableDeferred into the cache, but only while the request
    # is genuinely still in flight. If it has already finished, its cleanup
    # has already run and nothing would ever remove the entry again.
    if not request_deferred.called:
        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
        group_request_dict[db_state_filter] = observable_deferred

    return await make_deferred_yieldable(observable_deferred.observe())
| 349 | + |
async def _get_state_for_group_using_inflight_cache(
    self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
    """
    Fetch the state at a state group (optionally restricted by type and/or
    state key), sharing work with requests that are already running.

    The work is split in two:
      1. _get_state_for_group_gather_inflight_requests picks the ongoing
         requests that overlap with what we need.
      2. _get_state_for_group_fire_request fetches whatever those ongoing
         requests do not cover.

    Args:
        group: ID of the state group for which we want to get state
        state_filter: the state filter used to fetch state from the database
    Returns:
        state map
    """
    # Which in-flight requests can we lean on, and what do they leave uncovered?
    in_flight, uncovered = self._get_state_for_group_gather_inflight_requests(
        group, state_filter
    )

    result: MutableStateMap[str] = {}
    if uncovered != StateFilter.none():
        # Some of the state is not being fetched by anyone yet: fetch it now.
        fetched = await self._get_state_for_group_fire_request(group, uncovered)
        result.update(fetched)

    # Wait for every request we decided to piggy-back on. On failure,
    # unwrapFirstError is applied to the failure from gatherResults.
    observers = [request.observe() for request in in_flight]
    gathering = make_deferred_yieldable(
        defer.gatherResults(observers, consumeErrors=True)
    )
    gathering.addErrback(unwrapFirstError)
    pieces = await gathering

    # Merge each shared result on top of what we fetched ourselves.
    for piece in pieces:
        result.update(piece)

    # The merged requests may have fetched more than we asked for; trim it.
    return state_filter.filter_state(result)
| 397 | + |
231 | 398 | async def _get_state_for_groups(
|
232 | 399 | self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
233 | 400 | ) -> Dict[int, MutableStateMap[str]]:
|
@@ -269,31 +436,17 @@ async def _get_state_for_groups(
|
269 | 436 | if not incomplete_groups:
|
270 | 437 | return state
|
271 | 438 |
|
272 |
| - cache_sequence_nm = self._state_group_cache.sequence |
273 |
| - cache_sequence_m = self._state_group_members_cache.sequence |
274 |
| - |
275 |
| - # Help the cache hit ratio by expanding the filter a bit |
276 |
| - db_state_filter = state_filter.return_expanded() |
277 |
| - |
278 |
| - group_to_state_dict = await self._get_state_groups_from_groups( |
279 |
| - list(incomplete_groups), state_filter=db_state_filter |
280 |
| - ) |
| 439 | + async def get_from_cache(group: int, state_filter: StateFilter) -> None: |
| 440 | + state[group] = await self._get_state_for_group_using_inflight_cache( |
| 441 | + group, state_filter |
| 442 | + ) |
281 | 443 |
|
282 |
| - # Now lets update the caches |
283 |
| - self._insert_into_cache( |
284 |
| - group_to_state_dict, |
285 |
| - db_state_filter, |
286 |
| - cache_seq_num_members=cache_sequence_m, |
287 |
| - cache_seq_num_non_members=cache_sequence_nm, |
| 444 | + await yieldable_gather_results( |
| 445 | + get_from_cache, |
| 446 | + incomplete_groups, |
| 447 | + state_filter, |
288 | 448 | )
|
289 | 449 |
|
290 |
| - # And finally update the result dict, by filtering out any extra |
291 |
| - # stuff we pulled out of the database. |
292 |
| - for group, group_state_dict in group_to_state_dict.items(): |
293 |
| - # We just replace any existing entries, as we will have loaded |
294 |
| - # everything we need from the database anyway. |
295 |
| - state[group] = state_filter.filter_state(group_state_dict) |
296 |
| - |
297 | 450 | return state
|
298 | 451 |
|
299 | 452 | def _get_state_for_groups_using_cache(
|
|
0 commit comments