Skip to content

✨RabbitMQ: Add RPC decorator for server side #4682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 30, 2023
1 change: 1 addition & 0 deletions .github/workflows/ci-testing-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,7 @@ jobs:
- name: typecheck
run: ./ci/github/unit-testing/service-library.bash typecheck
- name: test
if: always()
run: ./ci/github/unit-testing/service-library.bash test_all
- uses: codecov/[email protected]
with:
Expand Down
1 change: 1 addition & 0 deletions packages/service-library/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ asyncio_mode = auto
markers =
testit: "marks test to run during development"
performance_test: "performance test"
no_cleanup_check_rabbitmq_server_has_no_errors: "no check in rabbitmq logs"
11 changes: 6 additions & 5 deletions packages/service-library/src/servicelib/background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import contextlib
import datetime
import logging
from typing import AsyncIterator, Awaitable, Callable, Final
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Final

from pydantic.errors import PydanticErrorMixin
from servicelib.logging_utils import log_catch, log_context
Expand Down Expand Up @@ -34,12 +35,12 @@ async def _periodic_scheduled_task(
with attempt:
with log_context(
logger,
logging.DEBUG,
logging.INFO,
msg=f"iteration {attempt.retry_state.attempt_number} of '{task_name}'",
), log_catch(logger):
await task(**task_kwargs)

raise TryAgain()
raise TryAgain


def start_periodic_task(
Expand All @@ -50,7 +51,7 @@ def start_periodic_task(
**kwargs,
) -> asyncio.Task:
with log_context(
logger, logging.INFO, msg=f"create periodic background task '{task_name}'"
logger, logging.DEBUG, msg=f"create periodic background task '{task_name}'"
):
return asyncio.create_task(
_periodic_scheduled_task(
Expand Down Expand Up @@ -97,7 +98,7 @@ async def stop_periodic_task(
) -> None:
with log_context(
logger,
logging.INFO,
logging.DEBUG,
msg=f"cancel periodic background task '{asyncio_task.get_name()}'",
):
await cancel_task(asyncio_task, timeout=timeout)
Expand Down
7 changes: 4 additions & 3 deletions packages/service-library/src/servicelib/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,13 @@ def log_catch(logger: logging.Logger, reraise: bool = True):
logger.debug("call was cancelled")
raise
except Exception as exc: # pylint: disable=broad-except
logger.error("Unhandled exception: %s", f"{exc}", exc_info=True)
logger.exception("Unhandled exception:")
if reraise:
raise exc from exc


un_capitalize = lambda s: s[:1].lower() + s[1:] if s else ""
def _un_capitalize(s):
return s[:1].lower() + s[1:] if s else ""


@contextmanager
Expand All @@ -246,7 +247,7 @@ def log_context(
):
# NOTE: preserves original signature https://docs.python.org/3/library/logging.html#logging.Logger.log
start = datetime.now() # noqa: DTZ005
msg = un_capitalize(msg.strip())
msg = _un_capitalize(msg.strip())

kwargs: dict[str, Any] = {}
if extra:
Expand Down
29 changes: 25 additions & 4 deletions packages/service-library/src/servicelib/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import functools
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
Expand All @@ -12,6 +13,7 @@
from settings_library.rabbit import RabbitSettings

from .rabbitmq_errors import RemoteMethodNotRegisteredError, RPCNotInitializedError
from .rabbitmq_rpc_router import RPCRouter
from .rabbitmq_utils import (
RPCMethodName,
RPCNamespace,
Expand All @@ -36,8 +38,9 @@ def routing_key(self) -> str | None:
...


_DEFAULT_RABBITMQ_SERVER_HEARTBEAT_S = 60
_DEFAULT_PREFETCH_VALUE = 10
_DEFAULT_RABBITMQ_SERVER_HEARTBEAT_S: Final[int] = 60
_DEFAULT_PREFETCH_VALUE: Final[int] = 10
_DEFAULT_RABBITMQ_EXECUTION_TIMEOUT_S: Final[int] = 5


@dataclass
Expand Down Expand Up @@ -82,7 +85,7 @@ def _connection_close_callback(

def _channel_close_callback(
self,
sender: Any, # pylint: disable=unused-argument
sender: Any, # pylint: disable=unused-argument # noqa: ARG002
exc: BaseException | None,
) -> None:
if exc:
Expand All @@ -108,6 +111,7 @@ async def _get_connection(
connection = await aio_pika.connect_robust(
url,
client_properties={"connection_name": connection_name},
timeout=_DEFAULT_RABBITMQ_EXECUTION_TIMEOUT_S,
)
connection.close_callbacks.add(self._connection_close_callback)
return connection
Expand Down Expand Up @@ -197,6 +201,7 @@ async def subscribe(
if topics is None
else aio_pika.ExchangeType.TOPIC,
durable=True,
timeout=_DEFAULT_RABBITMQ_EXECUTION_TIMEOUT_S,
)

# NOTE: durable=True makes the queue persistent between RabbitMQ restarts/crashes
Expand Down Expand Up @@ -295,13 +300,15 @@ async def publish(self, exchange_name: str, message: RabbitMessage) -> None:
"""
assert self._channel_pool # nosec
topic = message.routing_key()

async with self._channel_pool.acquire() as channel:
exchange = await channel.declare_exchange(
exchange_name,
aio_pika.ExchangeType.FANOUT
if topic is None
else aio_pika.ExchangeType.TOPIC,
durable=True,
timeout=_DEFAULT_RABBITMQ_EXECUTION_TIMEOUT_S,
)
await exchange.publish(
aio_pika.Message(message.body()),
Expand Down Expand Up @@ -346,7 +353,7 @@ async def rpc_request(
raise RemoteMethodNotRegisteredError(
method_name=namespaced_method_name, incoming_message=e.args[1]
) from e
raise e
raise

async def rpc_register_handler(
self,
Expand All @@ -371,6 +378,20 @@ async def rpc_register_handler(
auto_delete=True,
)

async def rpc_register_router(
    self,
    router: RPCRouter,
    namespace: RPCNamespace,
    *handler_args,
    **handler_kwargs,
) -> None:
    """Register every route collected in *router* under *namespace*.

    Any extra positional/keyword arguments are pre-bound to each handler
    (via functools.partial) before it is registered on the RPC server.
    """
    for method_name, route_handler in router.routes.items():
        bound_handler = functools.partial(
            route_handler, *handler_args, **handler_kwargs
        )
        await self.rpc_register_handler(namespace, method_name, bound_handler)

async def rpc_unregister_handler(self, handler: Callable[..., Any]) -> None:
"""Unbind a locally added `handler`"""

Expand Down
46 changes: 46 additions & 0 deletions packages/service-library/src/servicelib/rabbitmq_rpc_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import functools
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, TypeVar

import orjson
from models_library.utils.fastapi_encoders import jsonable_encoder
from pydantic import SecretStr
from servicelib.logging_utils import log_catch, log_context
from servicelib.rabbitmq_utils import RPCMethodName

# TypeVar preserving the decorated callable's exact type through the decorator
# (same pattern FastAPI uses for route decorators).
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])

# Dedicated logger name so RPC access logs can be filtered/configured separately.
_logger = logging.getLogger("rpc.access")

# Extra encoders passed to jsonable_encoder when serializing RPC results:
# SecretStr values are unwrapped to their plain value for the wire format.
_RPC_CUSTOM_ENCODER: dict[Any, Callable[[Any], Any]] = {
    SecretStr: SecretStr.get_secret_value
}


@dataclass
class RPCRouter:
    """Collects RPC handlers so they can later be registered on a RabbitMQ client."""

    # method name -> serializing async wrapper around the exposed coroutine
    routes: dict[RPCMethodName, Callable] = field(default_factory=dict)

    def expose(self) -> Callable[[DecoratedCallable], DecoratedCallable]:
        """Decorator registering an async handler in this router under its own name.

        The stored route wraps the handler so its return value is encoded
        (with _RPC_CUSTOM_ENCODER) and dumped to JSON bytes; the decorated
        function itself is handed back unchanged to the caller.
        """

        def _register(func: DecoratedCallable) -> DecoratedCallable:
            @functools.wraps(func)
            async def _serializing_handler(*args, **kwargs):
                with log_context(
                    _logger,
                    logging.INFO,
                    msg=f"calling {func.__name__} with {args}, {kwargs}",
                ), log_catch(_logger, reraise=True):
                    outcome = await func(*args, **kwargs)
                    encoded = jsonable_encoder(
                        outcome,
                        custom_encoder=_RPC_CUSTOM_ENCODER,
                    )
                    return orjson.dumps(encoded)

            self.routes[RPCMethodName(func.__name__)] = _serializing_handler
            return func

        return _register
36 changes: 27 additions & 9 deletions packages/service-library/tests/rabbitmq/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import datetime
import time
from typing import AsyncIterator, Coroutine, cast
from collections.abc import AsyncIterator, Callable, Coroutine
from typing import cast

import aiodocker
import arrow
import pytest
from faker import Faker


@pytest.fixture
async def cleanup_check_rabbitmq_server_has_no_errors() -> AsyncIterator[None]:
now = datetime.datetime.now()
@pytest.fixture(autouse=True)
async def cleanup_check_rabbitmq_server_has_no_errors(
request: pytest.FixtureRequest,
) -> AsyncIterator[None]:
now = arrow.utcnow()
yield
if "no_cleanup_check_rabbitmq_server_has_no_errors" in request.keywords:
return
print("--> checking for errors/warnings in rabbitmq logs...")
async with aiodocker.Docker() as docker_client:
containers = await docker_client.containers.list(filters=({"name": ["rabbit"]}))
Expand All @@ -22,7 +27,7 @@ async def cleanup_check_rabbitmq_server_has_no_errors() -> AsyncIterator[None]:
stdout=True,
stderr=True,
follow=False,
since=time.mktime(now.timetuple()),
since=now.timestamp(),
),
)

Expand All @@ -36,6 +41,19 @@ async def cleanup_check_rabbitmq_server_has_no_errors() -> AsyncIterator[None]:
for log in warning_logs
if all(w not in log for w in RABBIT_SKIPPED_WARNINGS)
]
assert not filtered_warning_logs
assert not error_logs
assert (
not filtered_warning_logs
), f"warning(s) found in rabbitmq logs for {request.function}"
assert not error_logs, f"error(s) found in rabbitmq logs for {request.function}"
print("<-- no error founds in rabbitmq server logs, that's great. good job!")


@pytest.fixture
def random_exchange_name() -> Callable[[], str]:
    """Factory fixture producing unique RabbitMQ exchange names for tests."""

    def _make_name() -> str:
        # A brand-new Faker (and thus a fresh seed) per call: exchanges are not
        # cleaned up between runs, so names must never repeat.
        return f"pytest_fake_exchange_{Faker().pystr()}"

    return _make_name
20 changes: 2 additions & 18 deletions packages/service-library/tests/rabbitmq/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def rabbit_client_name(faker: Faker) -> str:
async def test_rabbit_client(
rabbit_client_name: str,
rabbit_service: RabbitSettings,
cleanup_check_rabbitmq_server_has_no_errors: None,
):
client = RabbitMQClient(rabbit_client_name, rabbit_service)
assert client
Expand All @@ -51,14 +50,6 @@ async def test_rabbit_client(
assert client._connection_pool.is_closed # noqa: SLF001


@pytest.fixture
def random_exchange_name(faker: Faker) -> Callable[[], str]:
def _creator() -> str:
return f"pytest_fake_exchange_{faker.pystr()}"

return _creator


@pytest.fixture
def mocked_message_parser(mocker: MockerFixture) -> mock.AsyncMock:
return mocker.AsyncMock(return_value=True)
Expand Down Expand Up @@ -120,7 +111,6 @@ async def _assert_message_received(


async def test_rabbit_client_pub_sub_message_is_lost_if_no_consumer_present(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocked_message_parser: mock.AsyncMock,
Expand All @@ -138,7 +128,6 @@ async def test_rabbit_client_pub_sub_message_is_lost_if_no_consumer_present(


async def test_rabbit_client_pub_sub(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocked_message_parser: mock.AsyncMock,
Expand All @@ -156,7 +145,6 @@ async def test_rabbit_client_pub_sub(

@pytest.mark.parametrize("num_subs", [10])
async def test_rabbit_client_pub_many_subs(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocker: MockerFixture,
Expand Down Expand Up @@ -188,7 +176,6 @@ async def test_rabbit_client_pub_many_subs(


async def test_rabbit_client_pub_sub_republishes_if_exception_raised(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocked_message_parser: mock.AsyncMock,
Expand Down Expand Up @@ -219,7 +206,6 @@ def _raise_once_then_true(*args, **kwargs):

@pytest.mark.parametrize("num_subs", [10])
async def test_pub_sub_with_non_exclusive_queue(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocker: MockerFixture,
Expand Down Expand Up @@ -257,7 +243,6 @@ async def test_pub_sub_with_non_exclusive_queue(


def test_rabbit_pub_sub_performance(
cleanup_check_rabbitmq_server_has_no_errors: None,
benchmark,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
Expand Down Expand Up @@ -285,7 +270,6 @@ def run_test_async():


async def test_rabbit_pub_sub_with_topic(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocker: MockerFixture,
Expand Down Expand Up @@ -338,7 +322,6 @@ async def test_rabbit_pub_sub_with_topic(


async def test_rabbit_pub_sub_bind_and_unbind_topics(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocked_message_parser: mock.AsyncMock,
Expand Down Expand Up @@ -408,8 +391,8 @@ async def test_rabbit_pub_sub_bind_and_unbind_topics(
await _assert_message_received(mocked_message_parser, 0)


@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors()
async def test_rabbit_adding_topics_to_a_fanout_exchange(
cleanup_check_rabbitmq_server_has_no_errors: None,
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
mocked_message_parser: mock.AsyncMock,
Expand Down Expand Up @@ -439,6 +422,7 @@ async def test_rabbit_adding_topics_to_a_fanout_exchange(
await _assert_message_received(mocked_message_parser, 0)


@pytest.mark.no_cleanup_check_rabbitmq_server_has_no_errors()
async def test_rabbit_not_using_the_same_exchange_type_raises(
rabbitmq_client: Callable[[str], RabbitMQClient],
random_exchange_name: Callable[[], str],
Expand Down
Loading