|
4 | 4 | import logging
|
5 | 5 | from collections.abc import Callable, Coroutine
|
6 | 6 | from datetime import timedelta
|
7 |
| -from typing import Any, ParamSpec, TypeVar |
| 7 | +from typing import Any, Final, ParamSpec, TypeVar |
8 | 8 |
|
9 | 9 | import redis.exceptions
|
| 10 | +from redis.asyncio.lock import Lock |
| 11 | +from tenacity import retry |
10 | 12 |
|
11 |
| -from ..background_task import periodic_task |
12 |
| -from ..logging_utils import log_context |
| 13 | +from servicelib.async_utils import cancel_wait_task, with_delay |
| 14 | +from servicelib.logging_utils import log_context |
| 15 | + |
| 16 | +from ..background_task import periodic |
13 | 17 | from ._client import RedisClientSDK
|
14 | 18 | from ._constants import DEFAULT_LOCK_TTL, SHUTDOWN_TIMEOUT_S
|
15 |
| -from ._errors import CouldNotAcquireLockError |
| 19 | +from ._errors import CouldNotAcquireLockError, LockLostError |
16 | 20 | from ._utils import auto_extend_lock
|
17 | 21 |
|
18 | 22 | _logger = logging.getLogger(__file__)
|
19 | 23 |
|
20 | 24 | P = ParamSpec("P")
|
21 | 25 | R = TypeVar("R")
|
22 | 26 |
|
| 27 | +_EXCLUSIVE_TASK_NAME: Final[str] = "exclusive/{func_name}" |
| 28 | +_EXCLUSIVE_AUTO_EXTEND_TASK_NAME: Final[str] = ( |
| 29 | + "exclusive/autoextend_lock_{redis_lock_key}" |
| 30 | +) |
| 31 | + |
| 32 | + |
@periodic(interval=DEFAULT_LOCK_TTL / 2, raise_on_error=True)
async def _periodic_auto_extender(lock: Lock, started_event: asyncio.Event) -> None:
    """Keep the distributed *lock* alive by periodically extending its TTL.

    Runs every ``DEFAULT_LOCK_TTL / 2`` (via the ``@periodic`` decorator) so the
    lock is refreshed well before it expires. Sets *started_event* on the first
    run so the caller can wait until the auto-extender is actually active before
    starting the guarded work task.

    Raises:
        LockLostError: propagated from ``auto_extend_lock`` if the lock
            disappeared from Redis (``raise_on_error=True`` lets it escape the
            periodic loop).
    """
    started_event.set()
    await auto_extend_lock(lock)
| 40 | + |
| 41 | + |
def _cancel_auto_extender_task(
    _: asyncio.Task, *, auto_extend_task: asyncio.Task
) -> None:
    """Done-callback for the work task: stop the lock auto-extender.

    Arguments:
        _: the completed work task passed in by ``add_done_callback`` (unused).
        auto_extend_task: the periodic lock-extension task to cancel once the
            work is finished.
    """
    with log_context(
        _logger,
        logging.DEBUG,
        f"Cancelling auto-extend task {auto_extend_task.get_name()}",
    ):
        auto_extend_task.cancel()
        # sanity check only (stripped under -O): cancel() above guarantees at
        # least one pending cancellation request on the task
        assert auto_extend_task.cancelling()  # nosec
| 52 | + |
23 | 53 |
|
24 | 54 | def exclusive(
|
25 | 55 | redis_client: RedisClientSDK | Callable[..., RedisClientSDK],
|
@@ -78,42 +108,58 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
78 | 108 | raise CouldNotAcquireLockError(lock=lock)
|
79 | 109 |
|
80 | 110 | try:
|
81 |
| - async with periodic_task( |
82 |
| - auto_extend_lock, |
83 |
| - interval=DEFAULT_LOCK_TTL / 2, |
84 |
| - task_name=f"autoextend_exclusive_lock_{redis_lock_key}", |
85 |
| - raise_on_error=True, |
86 |
| - lock=lock, |
87 |
| - ) as auto_extend_task: |
88 |
| - assert asyncio.iscoroutinefunction(func) # nosec |
89 |
| - work_task = asyncio.create_task( |
90 |
| - func(*args, **kwargs), name=f"exclusive_{func.__name__}" |
91 |
| - ) |
92 |
| - done, _pending = await asyncio.wait( |
93 |
| - [work_task, auto_extend_task], |
94 |
| - return_when=asyncio.FIRST_COMPLETED, |
| 111 | + async with asyncio.TaskGroup() as tg: |
| 112 | + started_event = asyncio.Event() |
| 113 | + # first create a task that will auto-extend the lock |
| 114 | + auto_extend_lock_task = tg.create_task( |
| 115 | + _periodic_auto_extender(lock, started_event), |
| 116 | + name=_EXCLUSIVE_AUTO_EXTEND_TASK_NAME.format( |
| 117 | + redis_lock_key=redis_lock_key |
| 118 | + ), |
95 | 119 | )
|
96 |
| - # the task finished, let's return its result whatever it is |
97 |
| - if work_task in done: |
98 |
| - return await work_task |
99 |
| - |
100 |
| - # the auto extend task can only finish if it raised an error, so it's bad |
101 |
| - _logger.error( |
102 |
| - "lock %s could not be auto-extended, cancelling work task! " |
103 |
| - "TIP: check connection to Redis DBs or look for Synchronous " |
104 |
| - "code that might block the auto-extender task.", |
105 |
| - lock.name, |
| 120 | + # NOTE: in case the work task is super short lived, then we might fail in cancelling it |
| 121 | + await started_event.wait() |
| 122 | + |
| 123 | + # then the task that runs the user code |
| 124 | + assert asyncio.iscoroutinefunction(func) # nosec |
| 125 | + work_task = tg.create_task( |
| 126 | + func(*args, **kwargs), |
| 127 | + name=_EXCLUSIVE_TASK_NAME.format(func_name=func.__name__), |
106 | 128 | )
|
107 |
| - with log_context(_logger, logging.DEBUG, msg="cancel work task"): |
108 |
| - work_task.cancel() |
109 |
| - with contextlib.suppress(asyncio.CancelledError, TimeoutError): |
110 |
| - # this will raise any other error that could have happened in the work task |
111 |
| - await asyncio.wait_for( |
112 |
| - work_task, timeout=SHUTDOWN_TIMEOUT_S |
113 |
| - ) |
114 |
| - # return the extend task raised error |
115 |
| - return await auto_extend_task # type: ignore[no-any-return] # will raise |
116 | 129 |
|
| 130 | + # work_task.add_done_callback( |
| 131 | + # functools.partial( |
| 132 | + # _cancel_auto_extender_task, |
| 133 | + # auto_extend_task=auto_extend_lock_task, |
| 134 | + # ) |
| 135 | + # ) |
| 136 | + |
| 137 | + res = await work_task |
| 138 | + auto_extend_lock_task.cancel() |
| 139 | + |
| 140 | + return res |
| 141 | + |
| 142 | + except BaseExceptionGroup as eg: |
| 143 | + breakpoint() |
| 144 | + # Separate exceptions into LockLostError and others |
| 145 | + lock_lost_errors, other_errors = eg.split(LockLostError) |
| 146 | + |
| 147 | + # If there are any other errors, re-raise them |
| 148 | + if other_errors: |
| 149 | + assert len(other_errors.exceptions) == 1 # nosec |
| 150 | + raise other_errors.exceptions[0] from eg |
| 151 | + |
| 152 | + assert lock_lost_errors is not None # nosec |
| 153 | + assert len(lock_lost_errors.exceptions) == 1 # nosec |
| 154 | + _logger.error( # noqa: TRY400 |
| 155 | + "lock %s could not be auto-extended! " |
| 156 | + "TIP: check connection to Redis DBs or look for Synchronous " |
| 157 | + "code that might block the auto-extender task. Somehow the distributed lock disappeared!", |
| 158 | + lock.name, |
| 159 | + ) |
| 160 | + raise lock_lost_errors.exceptions[0] from eg |
| 161 | + except Exception as exc: |
| 162 | + breakpoint() |
117 | 163 | finally:
|
118 | 164 | with contextlib.suppress(redis.exceptions.LockNotOwnedError):
|
119 | 165 | # in the case where the lock would have been lost,
|
|
0 commit comments