Skip to content

Commit 445c6de

Browse files
authored
✨Async Jobs: add some tests + an internal client (#7410)
1 parent 180355d commit 445c6de

File tree

9 files changed

+492
-38
lines changed

9 files changed

+492
-38
lines changed

packages/service-library/src/servicelib/progress_bar.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@
2222

2323
@runtime_checkable
2424
class AsyncReportCB(Protocol):
25-
async def __call__(self, report: ProgressReport) -> None:
26-
...
25+
async def __call__(self, report: ProgressReport) -> None: ...
2726

2827

2928
@runtime_checkable
3029
class ReportCB(Protocol):
31-
def __call__(self, report: ProgressReport) -> None:
32-
...
30+
def __call__(self, report: ProgressReport) -> None: ...
3331

3432

3533
def _normalize_weights(steps: int, weights: list[float]) -> list[float]:
@@ -88,7 +86,7 @@ async def main_fct():
8886
progress_unit: ProgressUnit | None = None
8987
progress_report_cb: AsyncReportCB | ReportCB | None = None
9088
_current_steps: float = _INITIAL_VALUE
91-
_currnet_attempt: int = 0
89+
_current_attempt: int = 0
9290
_children: list["ProgressBarData"] = field(default_factory=list)
9391
_parent: Optional["ProgressBarData"] = None
9492
_continuous_value_lock: asyncio.Lock = field(init=False)
@@ -148,7 +146,7 @@ async def _report_external(self, value: float) -> None:
148146
# NOTE: here we convert back to actual value since this is possibly weighted
149147
actual_value=value * self.num_steps,
150148
total=self.num_steps,
151-
attempt=self._currnet_attempt,
149+
attempt=self._current_attempt,
152150
unit=self.progress_unit,
153151
message=self.compute_report_message_stuct(),
154152
),
@@ -200,7 +198,7 @@ async def update(self, steps: float = 1) -> None:
200198
await self._report_external(new_progress_value)
201199

202200
def reset(self) -> None:
203-
self._currnet_attempt += 1
201+
self._current_attempt += 1
204202
self._current_steps = _INITIAL_VALUE
205203
self._last_report_value = _INITIAL_VALUE
206204

packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py

+133-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
from typing import Final
1+
import datetime
2+
import logging
3+
from asyncio import CancelledError
4+
from collections.abc import AsyncGenerator, Awaitable
5+
from typing import Any, Final
26

7+
from attr import dataclass
38
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
49
AsyncJobGet,
510
AsyncJobId,
@@ -9,12 +14,25 @@
914
)
1015
from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace
1116
from pydantic import NonNegativeInt, TypeAdapter
17+
from tenacity import (
18+
AsyncRetrying,
19+
TryAgain,
20+
before_sleep_log,
21+
retry,
22+
retry_if_exception_type,
23+
stop_after_delay,
24+
wait_fixed,
25+
wait_random_exponential,
26+
)
1227

28+
from ....rabbitmq import RemoteMethodNotRegisteredError
1329
from ... import RabbitMQRPCClient
1430

1531
_DEFAULT_TIMEOUT_S: Final[NonNegativeInt] = 30
1632

1733
_RPC_METHOD_NAME_ADAPTER = TypeAdapter(RPCMethodName)
34+
_DEFAULT_POLL_INTERVAL_S: Final[float] = 0.1
35+
_logger = logging.getLogger(__name__)
1836

1937

2038
async def cancel(
@@ -103,3 +121,117 @@ async def submit(
103121
)
104122
assert isinstance(_result, AsyncJobGet) # nosec
105123
return _result
124+
125+
126+
_DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = {
127+
"retry": retry_if_exception_type(RemoteMethodNotRegisteredError),
128+
"wait": wait_random_exponential(max=20),
129+
"stop": stop_after_delay(60),
130+
"reraise": True,
131+
"before_sleep": before_sleep_log(_logger, logging.INFO),
132+
}
133+
134+
135+
@retry(**_DEFAULT_RPC_RETRY_POLICY)
136+
async def _wait_for_completion(
137+
rabbitmq_rpc_client: RabbitMQRPCClient,
138+
*,
139+
rpc_namespace: RPCNamespace,
140+
method_name: RPCMethodName,
141+
job_id: AsyncJobId,
142+
job_id_data: AsyncJobNameData,
143+
client_timeout: datetime.timedelta,
144+
) -> AsyncGenerator[AsyncJobStatus, None]:
145+
try:
146+
async for attempt in AsyncRetrying(
147+
stop=stop_after_delay(client_timeout.total_seconds()),
148+
reraise=True,
149+
retry=retry_if_exception_type(TryAgain),
150+
before_sleep=before_sleep_log(_logger, logging.DEBUG),
151+
wait=wait_fixed(_DEFAULT_POLL_INTERVAL_S),
152+
):
153+
with attempt:
154+
job_status = await status(
155+
rabbitmq_rpc_client,
156+
rpc_namespace=rpc_namespace,
157+
job_id=job_id,
158+
job_id_data=job_id_data,
159+
)
160+
yield job_status
161+
if not job_status.done:
162+
msg = f"{job_status.job_id=}: '{job_status.progress=}'"
163+
raise TryAgain(msg) # noqa: TRY301
164+
165+
except TryAgain as exc:
166+
# this is a timeout
167+
msg = f"Async job {job_id=}, calling to '{method_name}' timed-out after {client_timeout}"
168+
raise TimeoutError(msg) from exc
169+
170+
171+
@dataclass(frozen=True)
172+
class AsyncJobComposedResult:
173+
status: AsyncJobStatus
174+
_result: Awaitable[Any] | None = None
175+
176+
@property
177+
def done(self) -> bool:
178+
return self._result is not None
179+
180+
async def result(self) -> Any:
181+
if not self._result:
182+
msg = "No result ready!"
183+
raise ValueError(msg)
184+
return await self._result
185+
186+
187+
async def submit_and_wait(
188+
rabbitmq_rpc_client: RabbitMQRPCClient,
189+
*,
190+
rpc_namespace: RPCNamespace,
191+
method_name: str,
192+
job_id_data: AsyncJobNameData,
193+
client_timeout: datetime.timedelta,
194+
**kwargs,
195+
) -> AsyncGenerator[AsyncJobComposedResult, None]:
196+
async_job_rpc_get = None
197+
try:
198+
async_job_rpc_get = await submit(
199+
rabbitmq_rpc_client,
200+
rpc_namespace=rpc_namespace,
201+
method_name=method_name,
202+
job_id_data=job_id_data,
203+
**kwargs,
204+
)
205+
job_status: AsyncJobStatus | None = None
206+
async for job_status in _wait_for_completion(
207+
rabbitmq_rpc_client,
208+
rpc_namespace=rpc_namespace,
209+
method_name=method_name,
210+
job_id=async_job_rpc_get.job_id,
211+
job_id_data=job_id_data,
212+
client_timeout=client_timeout,
213+
):
214+
assert job_status is not None # nosec
215+
yield AsyncJobComposedResult(job_status)
216+
if job_status:
217+
yield AsyncJobComposedResult(
218+
job_status,
219+
result(
220+
rabbitmq_rpc_client,
221+
rpc_namespace=rpc_namespace,
222+
job_id=async_job_rpc_get.job_id,
223+
job_id_data=job_id_data,
224+
),
225+
)
226+
except (TimeoutError, CancelledError) as error:
227+
if async_job_rpc_get is not None:
228+
try:
229+
await cancel(
230+
rabbitmq_rpc_client,
231+
rpc_namespace=rpc_namespace,
232+
job_id=async_job_rpc_get.job_id,
233+
job_id_data=job_id_data,
234+
)
235+
except Exception as exc:
236+
raise exc from error
237+
raise

packages/service-library/tests/rabbitmq/conftest.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1-
from collections.abc import AsyncIterator, Callable, Coroutine
1+
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
22
from typing import cast
33

44
import aiodocker
55
import arrow
66
import pytest
77
from faker import Faker
8+
from models_library.rabbitmq_basic_types import RPCNamespace
9+
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
10+
11+
12+
@pytest.fixture
13+
async def rpc_client(
14+
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
15+
) -> RabbitMQRPCClient:
16+
return await rabbitmq_rpc_client("pytest_rpc_client")
17+
18+
19+
@pytest.fixture
20+
async def rpc_server(
21+
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
22+
) -> RabbitMQRPCClient:
23+
return await rabbitmq_rpc_client("pytest_rpc_server")
24+
25+
26+
@pytest.fixture
27+
def namespace() -> RPCNamespace:
28+
return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)})
829

930

1031
@pytest.fixture(autouse=True)

packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# pylint:disable=unused-argument
33

44
import asyncio
5-
from collections.abc import Awaitable, Callable
5+
from collections.abc import Awaitable
66
from typing import Any, Final
77

88
import pytest
@@ -23,11 +23,6 @@
2323
MULTIPLE_REQUESTS_COUNT: Final[NonNegativeInt] = 100
2424

2525

26-
@pytest.fixture
27-
def namespace() -> RPCNamespace:
28-
return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)})
29-
30-
3126
async def add_me(*, x: Any, y: Any) -> Any:
3227
return x + y
3328
# NOTE: types are not enforced
@@ -49,20 +44,6 @@ def __add__(self, other: "CustomClass") -> "CustomClass":
4944
return CustomClass(x=self.x + other.x, y=self.y + other.y)
5045

5146

52-
@pytest.fixture
53-
async def rpc_client(
54-
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
55-
) -> RabbitMQRPCClient:
56-
return await rabbitmq_rpc_client("pytest_rpc_client")
57-
58-
59-
@pytest.fixture
60-
async def rpc_server(
61-
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
62-
) -> RabbitMQRPCClient:
63-
return await rabbitmq_rpc_client("pytest_rpc_server")
64-
65-
6647
@pytest.mark.parametrize(
6748
"x,y,expected_result,expected_type",
6849
[

0 commit comments

Comments
 (0)