|
18 | 18 | import logging
|
19 | 19 | from typing import (
|
20 | 20 | Any,
|
| 21 | + Awaitable, |
21 | 22 | Callable,
|
22 | 23 | Dict,
|
23 | 24 | Generic,
|
@@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
346 | 347 | """Wraps an existing cache to support bulk fetching of keys.
|
347 | 348 |
|
348 | 349 | Given an iterable of keys it looks in the cache to find any hits, then passes
|
349 |
| - the tuple of missing keys to the wrapped function. |
| 350 | + the set of missing keys to the wrapped function. |
350 | 351 |
|
351 |
| - Once wrapped, the function returns a Deferred which resolves to the list |
352 |
| - of results. |
| 352 | + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from |
| 353 | + input key to output value. |
353 | 354 | """
|
354 | 355 |
|
355 | 356 | def __init__(
|
356 | 357 | self,
|
357 |
| - orig: Callable[..., Any], |
| 358 | + orig: Callable[..., Awaitable[Dict]], |
358 | 359 | cached_method_name: str,
|
359 | 360 | list_name: str,
|
360 | 361 | num_args: Optional[int] = None,
|
@@ -385,13 +386,13 @@ def __init__(
|
385 | 386 |
|
386 | 387 | def __get__(
|
387 | 388 | self, obj: Optional[Any], objtype: Optional[Type] = None
|
388 |
| - ) -> Callable[..., Any]: |
| 389 | + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: |
389 | 390 | cached_method = getattr(obj, self.cached_method_name)
|
390 | 391 | cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
391 | 392 | num_args = cached_method.num_args
|
392 | 393 |
|
393 | 394 | @functools.wraps(self.orig)
|
394 |
| - def wrapped(*args: Any, **kwargs: Any) -> Any: |
| 395 | + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": |
395 | 396 | # If we're passed a cache_context then we'll want to call its
|
396 | 397 | # invalidate() whenever we are invalidated
|
397 | 398 | invalidate_callback = kwargs.pop("on_invalidate", None)
|
@@ -444,39 +445,38 @@ def arg_to_cache_key(arg: Hashable) -> Hashable:
|
444 | 445 | deferred: "defer.Deferred[Any]" = defer.Deferred()
|
445 | 446 | deferreds_map[arg] = deferred
|
446 | 447 | key = arg_to_cache_key(arg)
|
447 |
| - cache.set(key, deferred, callback=invalidate_callback) |
| 448 | + cached_defers.append( |
| 449 | + cache.set(key, deferred, callback=invalidate_callback) |
| 450 | + ) |
448 | 451 |
|
449 | 452 | def complete_all(res: Dict[Hashable, Any]) -> None:
|
450 |
| - # the wrapped function has completed. It returns a |
451 |
| - # a dict. We can now resolve the observable deferreds in |
452 |
| - # the cache and update our own result map. |
453 |
| - for e in missing: |
| 453 | + # the wrapped function has completed. It returns a dict. |
| 454 | + # We can now update our own result map, and then resolve the |
| 455 | + # observable deferreds in the cache. |
| 456 | + for e, d1 in deferreds_map.items(): |
454 | 457 | val = res.get(e, None)
|
455 |
| - deferreds_map[e].callback(val) |
| 458 | + # make sure we update the results map before running the |
| 459 | + # deferreds, because as soon as we run the last deferred, the |
| 460 | + # gatherResults() below will complete and return the result |
| 461 | + # dict to our caller. |
456 | 462 | results[e] = val
|
| 463 | + d1.callback(val) |
457 | 464 |
|
458 |
| - def errback(f: Failure) -> Failure: |
459 |
| - # the wrapped function has failed. Invalidate any cache |
460 |
| - # entries we're supposed to be populating, and fail |
461 |
| - # their deferreds. |
462 |
| - for e in missing: |
463 |
| - key = arg_to_cache_key(e) |
464 |
| - cache.invalidate(key) |
465 |
| - deferreds_map[e].errback(f) |
466 |
| - |
467 |
| - # return the failure, to propagate to our caller. |
468 |
| - return f |
| 465 | + def errback_all(f: Failure) -> None: |
| 466 | + # the wrapped function has failed. Propagate the failure into |
| 467 | + # the cache, which will invalidate the entry, and cause the |
| 468 | + # relevant cached_deferreds to fail, which will propagate the |
| 469 | + # failure to our caller. |
| 470 | + for d1 in deferreds_map.values(): |
| 471 | + d1.errback(f) |
469 | 472 |
|
470 | 473 | args_to_call = dict(arg_dict)
|
471 |
| - # copy the missing set before sending it to the callee, to guard against |
472 |
| - # modification. |
473 |
| - args_to_call[self.list_name] = tuple(missing) |
474 |
| - |
475 |
| - cached_defers.append( |
476 |
| - defer.maybeDeferred( |
477 |
| - preserve_fn(self.orig), **args_to_call |
478 |
| - ).addCallbacks(complete_all, errback) |
479 |
| - ) |
| 474 | + args_to_call[self.list_name] = missing |
| 475 | + |
| 476 | + # dispatch the call, and attach the two handlers |
| 477 | + defer.maybeDeferred( |
| 478 | + preserve_fn(self.orig), **args_to_call |
| 479 | + ).addCallbacks(complete_all, errback_all) |
480 | 480 |
|
481 | 481 | if cached_defers:
|
482 | 482 | d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|
|
0 commit comments