Skip to content

✨Async Jobs: add some tests + an internal client #7410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions packages/service-library/src/servicelib/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@

@runtime_checkable
class AsyncReportCB(Protocol):
async def __call__(self, report: ProgressReport) -> None:
...
async def __call__(self, report: ProgressReport) -> None: ...


@runtime_checkable
class ReportCB(Protocol):
def __call__(self, report: ProgressReport) -> None:
...
def __call__(self, report: ProgressReport) -> None: ...


def _normalize_weights(steps: int, weights: list[float]) -> list[float]:
Expand Down Expand Up @@ -88,7 +86,7 @@ async def main_fct():
progress_unit: ProgressUnit | None = None
progress_report_cb: AsyncReportCB | ReportCB | None = None
_current_steps: float = _INITIAL_VALUE
_currnet_attempt: int = 0
_current_attempt: int = 0
_children: list["ProgressBarData"] = field(default_factory=list)
_parent: Optional["ProgressBarData"] = None
_continuous_value_lock: asyncio.Lock = field(init=False)
Expand Down Expand Up @@ -148,7 +146,7 @@ async def _report_external(self, value: float) -> None:
# NOTE: here we convert back to actual value since this is possibly weighted
actual_value=value * self.num_steps,
total=self.num_steps,
attempt=self._currnet_attempt,
attempt=self._current_attempt,
unit=self.progress_unit,
message=self.compute_report_message_stuct(),
),
Expand Down Expand Up @@ -200,7 +198,7 @@ async def update(self, steps: float = 1) -> None:
await self._report_external(new_progress_value)

def reset(self) -> None:
self._currnet_attempt += 1
self._current_attempt += 1
self._current_steps = _INITIAL_VALUE
self._last_report_value = _INITIAL_VALUE

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Final
import datetime
import logging
from asyncio import CancelledError
from collections.abc import AsyncGenerator, Awaitable
from typing import Any, Final

from attr import dataclass
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
AsyncJobGet,
AsyncJobId,
Expand All @@ -9,12 +14,25 @@
)
from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace
from pydantic import NonNegativeInt, TypeAdapter
from tenacity import (
AsyncRetrying,
TryAgain,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_delay,
wait_fixed,
wait_random_exponential,
)

from ....rabbitmq import RemoteMethodNotRegisteredError
from ... import RabbitMQRPCClient

_DEFAULT_TIMEOUT_S: Final[NonNegativeInt] = 30

_RPC_METHOD_NAME_ADAPTER = TypeAdapter(RPCMethodName)
_DEFAULT_POLL_INTERVAL_S: Final[float] = 0.1
_logger = logging.getLogger(__name__)


async def cancel(
Expand Down Expand Up @@ -103,3 +121,117 @@ async def submit(
)
assert isinstance(_result, AsyncJobGet) # nosec
return _result


_DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = {
"retry": retry_if_exception_type(RemoteMethodNotRegisteredError),
"wait": wait_random_exponential(max=20),
"stop": stop_after_delay(60),
"reraise": True,
"before_sleep": before_sleep_log(_logger, logging.INFO),
}


@retry(**_DEFAULT_RPC_RETRY_POLICY)
async def _wait_for_completion(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
rpc_namespace: RPCNamespace,
method_name: RPCMethodName,
job_id: AsyncJobId,
job_id_data: AsyncJobNameData,
client_timeout: datetime.timedelta,
) -> AsyncGenerator[AsyncJobStatus, None]:
try:
async for attempt in AsyncRetrying(
stop=stop_after_delay(client_timeout.total_seconds()),
reraise=True,
retry=retry_if_exception_type(TryAgain),
before_sleep=before_sleep_log(_logger, logging.DEBUG),
wait=wait_fixed(_DEFAULT_POLL_INTERVAL_S),
):
with attempt:
job_status = await status(
rabbitmq_rpc_client,
rpc_namespace=rpc_namespace,
job_id=job_id,
job_id_data=job_id_data,
)
yield job_status
if not job_status.done:
msg = f"{job_status.job_id=}: '{job_status.progress=}'"
raise TryAgain(msg) # noqa: TRY301

except TryAgain as exc:
# this is a timeout
msg = f"Async job {job_id=}, calling to '{method_name}' timed-out after {client_timeout}"
raise TimeoutError(msg) from exc


@dataclass(frozen=True)
class AsyncJobComposedResult:
status: AsyncJobStatus
_result: Awaitable[Any] | None = None

@property
def done(self) -> bool:
return self._result is not None

async def result(self) -> Any:
if not self._result:
msg = "No result ready!"
raise ValueError(msg)
return await self._result


async def submit_and_wait(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
rpc_namespace: RPCNamespace,
method_name: str,
job_id_data: AsyncJobNameData,
client_timeout: datetime.timedelta,
**kwargs,
) -> AsyncGenerator[AsyncJobComposedResult, None]:
async_job_rpc_get = None
try:
async_job_rpc_get = await submit(
rabbitmq_rpc_client,
rpc_namespace=rpc_namespace,
method_name=method_name,
job_id_data=job_id_data,
**kwargs,
)
job_status: AsyncJobStatus | None = None
async for job_status in _wait_for_completion(
rabbitmq_rpc_client,
rpc_namespace=rpc_namespace,
method_name=method_name,
job_id=async_job_rpc_get.job_id,
job_id_data=job_id_data,
client_timeout=client_timeout,
):
assert job_status is not None # nosec
yield AsyncJobComposedResult(job_status)
if job_status:
yield AsyncJobComposedResult(
job_status,
result(
rabbitmq_rpc_client,
rpc_namespace=rpc_namespace,
job_id=async_job_rpc_get.job_id,
job_id_data=job_id_data,
),
)
except (TimeoutError, CancelledError) as error:
if async_job_rpc_get is not None:
try:
await cancel(
rabbitmq_rpc_client,
rpc_namespace=rpc_namespace,
job_id=async_job_rpc_get.job_id,
job_id_data=job_id_data,
)
except Exception as exc:
raise exc from error
raise
23 changes: 22 additions & 1 deletion packages/service-library/tests/rabbitmq/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
from collections.abc import AsyncIterator, Callable, Coroutine
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from typing import cast

import aiodocker
import arrow
import pytest
from faker import Faker
from models_library.rabbitmq_basic_types import RPCNamespace
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient


@pytest.fixture
async def rpc_client(
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
) -> RabbitMQRPCClient:
return await rabbitmq_rpc_client("pytest_rpc_client")


@pytest.fixture
async def rpc_server(
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
) -> RabbitMQRPCClient:
return await rabbitmq_rpc_client("pytest_rpc_server")


@pytest.fixture
def namespace() -> RPCNamespace:
return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)})


@pytest.fixture(autouse=True)
Expand Down
21 changes: 1 addition & 20 deletions packages/service-library/tests/rabbitmq/test_rabbitmq_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# pylint:disable=unused-argument

import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable
from typing import Any, Final

import pytest
Expand All @@ -23,11 +23,6 @@
MULTIPLE_REQUESTS_COUNT: Final[NonNegativeInt] = 100


@pytest.fixture
def namespace() -> RPCNamespace:
return RPCNamespace.from_entries({f"test{i}": f"test{i}" for i in range(8)})


async def add_me(*, x: Any, y: Any) -> Any:
return x + y
# NOTE: types are not enforced
Expand All @@ -49,20 +44,6 @@ def __add__(self, other: "CustomClass") -> "CustomClass":
return CustomClass(x=self.x + other.x, y=self.y + other.y)


@pytest.fixture
async def rpc_client(
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
) -> RabbitMQRPCClient:
return await rabbitmq_rpc_client("pytest_rpc_client")


@pytest.fixture
async def rpc_server(
rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]],
) -> RabbitMQRPCClient:
return await rabbitmq_rpc_client("pytest_rpc_server")


@pytest.mark.parametrize(
"x,y,expected_result,expected_type",
[
Expand Down
Loading
Loading