|
1 |
| -from typing import Final |
| 1 | +import datetime |
| 2 | +import logging |
| 3 | +from asyncio import CancelledError |
| 4 | +from collections.abc import AsyncGenerator, Awaitable |
| 5 | +from typing import Any, Final |
2 | 6 |
|
| 7 | +from attr import dataclass |
3 | 8 | from models_library.api_schemas_rpc_async_jobs.async_jobs import (
|
4 | 9 | AsyncJobGet,
|
5 | 10 | AsyncJobId,
|
|
9 | 14 | )
|
10 | 15 | from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace
|
11 | 16 | from pydantic import NonNegativeInt, TypeAdapter
|
| 17 | +from tenacity import ( |
| 18 | + AsyncRetrying, |
| 19 | + TryAgain, |
| 20 | + before_sleep_log, |
| 21 | + retry, |
| 22 | + retry_if_exception_type, |
| 23 | + stop_after_delay, |
| 24 | + wait_fixed, |
| 25 | + wait_random_exponential, |
| 26 | +) |
12 | 27 |
|
| 28 | +from ....rabbitmq import RemoteMethodNotRegisteredError |
13 | 29 | from ... import RabbitMQRPCClient
|
14 | 30 |
|
15 | 31 | _DEFAULT_TIMEOUT_S: Final[NonNegativeInt] = 30
|
16 | 32 |
|
17 | 33 | _RPC_METHOD_NAME_ADAPTER = TypeAdapter(RPCMethodName)
|
| 34 | +_DEFAULT_POLL_INTERVAL_S: Final[float] = 0.1 |
| 35 | +_logger = logging.getLogger(__name__) |
18 | 36 |
|
19 | 37 |
|
20 | 38 | async def cancel(
|
@@ -103,3 +121,117 @@ async def submit(
|
103 | 121 | )
|
104 | 122 | assert isinstance(_result, AsyncJobGet) # nosec
|
105 | 123 | return _result
|
| 124 | + |
| 125 | + |
| 126 | +_DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = { |
| 127 | + "retry": retry_if_exception_type(RemoteMethodNotRegisteredError), |
| 128 | + "wait": wait_random_exponential(max=20), |
| 129 | + "stop": stop_after_delay(60), |
| 130 | + "reraise": True, |
| 131 | + "before_sleep": before_sleep_log(_logger, logging.INFO), |
| 132 | +} |
| 133 | + |
| 134 | + |
| 135 | +@retry(**_DEFAULT_RPC_RETRY_POLICY) |
| 136 | +async def _wait_for_completion( |
| 137 | + rabbitmq_rpc_client: RabbitMQRPCClient, |
| 138 | + *, |
| 139 | + rpc_namespace: RPCNamespace, |
| 140 | + method_name: RPCMethodName, |
| 141 | + job_id: AsyncJobId, |
| 142 | + job_id_data: AsyncJobNameData, |
| 143 | + client_timeout: datetime.timedelta, |
| 144 | +) -> AsyncGenerator[AsyncJobStatus, None]: |
| 145 | + try: |
| 146 | + async for attempt in AsyncRetrying( |
| 147 | + stop=stop_after_delay(client_timeout.total_seconds()), |
| 148 | + reraise=True, |
| 149 | + retry=retry_if_exception_type(TryAgain), |
| 150 | + before_sleep=before_sleep_log(_logger, logging.DEBUG), |
| 151 | + wait=wait_fixed(_DEFAULT_POLL_INTERVAL_S), |
| 152 | + ): |
| 153 | + with attempt: |
| 154 | + job_status = await status( |
| 155 | + rabbitmq_rpc_client, |
| 156 | + rpc_namespace=rpc_namespace, |
| 157 | + job_id=job_id, |
| 158 | + job_id_data=job_id_data, |
| 159 | + ) |
| 160 | + yield job_status |
| 161 | + if not job_status.done: |
| 162 | + msg = f"{job_status.job_id=}: '{job_status.progress=}'" |
| 163 | + raise TryAgain(msg) # noqa: TRY301 |
| 164 | + |
| 165 | + except TryAgain as exc: |
| 166 | + # this is a timeout |
| 167 | + msg = f"Async job {job_id=}, calling to '{method_name}' timed-out after {client_timeout}" |
| 168 | + raise TimeoutError(msg) from exc |
| 169 | + |
| 170 | + |
| 171 | +@dataclass(frozen=True) |
| 172 | +class AsyncJobComposedResult: |
| 173 | + status: AsyncJobStatus |
| 174 | + _result: Awaitable[Any] | None = None |
| 175 | + |
| 176 | + @property |
| 177 | + def done(self) -> bool: |
| 178 | + return self._result is not None |
| 179 | + |
| 180 | + async def result(self) -> Any: |
| 181 | + if not self._result: |
| 182 | + msg = "No result ready!" |
| 183 | + raise ValueError(msg) |
| 184 | + return await self._result |
| 185 | + |
| 186 | + |
| 187 | +async def submit_and_wait( |
| 188 | + rabbitmq_rpc_client: RabbitMQRPCClient, |
| 189 | + *, |
| 190 | + rpc_namespace: RPCNamespace, |
| 191 | + method_name: str, |
| 192 | + job_id_data: AsyncJobNameData, |
| 193 | + client_timeout: datetime.timedelta, |
| 194 | + **kwargs, |
| 195 | +) -> AsyncGenerator[AsyncJobComposedResult, None]: |
| 196 | + async_job_rpc_get = None |
| 197 | + try: |
| 198 | + async_job_rpc_get = await submit( |
| 199 | + rabbitmq_rpc_client, |
| 200 | + rpc_namespace=rpc_namespace, |
| 201 | + method_name=method_name, |
| 202 | + job_id_data=job_id_data, |
| 203 | + **kwargs, |
| 204 | + ) |
| 205 | + job_status: AsyncJobStatus | None = None |
| 206 | + async for job_status in _wait_for_completion( |
| 207 | + rabbitmq_rpc_client, |
| 208 | + rpc_namespace=rpc_namespace, |
| 209 | + method_name=method_name, |
| 210 | + job_id=async_job_rpc_get.job_id, |
| 211 | + job_id_data=job_id_data, |
| 212 | + client_timeout=client_timeout, |
| 213 | + ): |
| 214 | + assert job_status is not None # nosec |
| 215 | + yield AsyncJobComposedResult(job_status) |
| 216 | + if job_status: |
| 217 | + yield AsyncJobComposedResult( |
| 218 | + job_status, |
| 219 | + result( |
| 220 | + rabbitmq_rpc_client, |
| 221 | + rpc_namespace=rpc_namespace, |
| 222 | + job_id=async_job_rpc_get.job_id, |
| 223 | + job_id_data=job_id_data, |
| 224 | + ), |
| 225 | + ) |
| 226 | + except (TimeoutError, CancelledError) as error: |
| 227 | + if async_job_rpc_get is not None: |
| 228 | + try: |
| 229 | + await cancel( |
| 230 | + rabbitmq_rpc_client, |
| 231 | + rpc_namespace=rpc_namespace, |
| 232 | + job_id=async_job_rpc_get.job_id, |
| 233 | + job_id_data=job_id_data, |
| 234 | + ) |
| 235 | + except Exception as exc: |
| 236 | + raise exc from error |
| 237 | + raise |
0 commit comments