Skip to content

Commit ed1c032

Browse files
authored
🎨 Add Reusable Lifespan Contexts for RabbitMQ and Redis in servicelib.fastapi (#7547)
1 parent 64f94c8 commit ed1c032

File tree

12 files changed

+572
-11
lines changed

12 files changed

+572
-11
lines changed
Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,77 @@
1+
import contextlib
2+
from collections.abc import Iterator
3+
from typing import Final
4+
15
from common_library.errors_classes import OsparcErrorMixin
6+
from fastapi import FastAPI
7+
from fastapi_lifespan_manager import State
8+
9+
from ..logging_utils import log_context
210

311

412
class LifespanError(OsparcErrorMixin, RuntimeError): ...
513

614

715
class LifespanOnStartupError(LifespanError):
8-
msg_template = "Failed during startup of {module}"
16+
msg_template = "Failed during startup of {lifespan_name}"
917

1018

1119
class LifespanOnShutdownError(LifespanError):
12-
msg_template = "Failed during shutdown of {module}"
20+
msg_template = "Failed during shutdown of {lifespan_name}"
21+
22+
23+
class LifespanAlreadyCalledError(LifespanError):
24+
msg_template = "The lifespan '{lifespan_name}' has already been called."
25+
26+
27+
class LifespanExpectedCalledError(LifespanError):
28+
msg_template = "The lifespan '{lifespan_name}' was not called. Ensure it is properly configured and invoked."
29+
30+
31+
_CALLED_LIFESPANS_KEY: Final[str] = "_CALLED_LIFESPANS"
32+
33+
34+
def is_lifespan_called(state: State, lifespan_name: str) -> bool:
35+
# NOTE: This assert is meant to catch a common mistake:
36+
# The `lifespan` function should accept up to two *optional* positional arguments: (app: FastAPI, state: State).
37+
# Valid signatures include: `()`, `(app)`, `(app, state)`, or even `(_, state)`.
38+
# It's easy to accidentally swap or misplace these arguments.
39+
assert not isinstance( # nosec
40+
state, FastAPI
41+
), "Did you swap arguments? `lifespan(app, state)` expects (app: FastAPI, state: State)"
42+
43+
called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
44+
return lifespan_name in called_lifespans
45+
46+
47+
def mark_lifespace_called(state: State, lifespan_name: str) -> State:
48+
"""Validates if a lifespan has already been called and records it in the state.
49+
Raises LifespanAlreadyCalledError if the lifespan has already been called.
50+
"""
51+
if is_lifespan_called(state, lifespan_name):
52+
raise LifespanAlreadyCalledError(lifespan_name=lifespan_name)
53+
54+
called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
55+
called_lifespans.add(lifespan_name)
56+
return {_CALLED_LIFESPANS_KEY: called_lifespans}
57+
58+
59+
def ensure_lifespan_called(state: State, lifespan_name: str) -> None:
60+
"""Ensures that a lifespan has been called.
61+
Raises LifespanNotCalledError if the lifespan has not been called.
62+
"""
63+
if not is_lifespan_called(state, lifespan_name):
64+
raise LifespanExpectedCalledError(lifespan_name=lifespan_name)
65+
66+
67+
@contextlib.contextmanager
68+
def lifespan_context(
69+
logger, level, lifespan_name: str, state: State
70+
) -> Iterator[State]:
71+
"""Helper context manager to log lifespan event and mark lifespan as called."""
72+
73+
with log_context(logger, level, lifespan_name):
74+
# Check if lifespan has already been called
75+
called_state = mark_lifespace_called(state, lifespan_name)
76+
77+
yield called_state

packages/service-library/src/servicelib/fastapi/postgres_lifespan.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from fastapi import FastAPI
77
from fastapi_lifespan_manager import State
8-
from servicelib.logging_utils import log_catch, log_context
98
from settings_library.postgres import PostgresSettings
109
from sqlalchemy.ext.asyncio import AsyncEngine
1110

1211
from ..db_asyncpg_utils import create_async_engine_and_database_ready
13-
from .lifespan_utils import LifespanOnStartupError
12+
from ..logging_utils import log_catch
13+
from .lifespan_utils import LifespanOnStartupError, lifespan_context
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -30,8 +30,10 @@ def create_postgres_database_input_state(settings: PostgresSettings) -> State:
3030

3131
async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
3232

33-
with log_context(_logger, logging.INFO, f"{__name__}"):
33+
_lifespan_name = f"{__name__}.{postgres_database_lifespan.__name__}"
3434

35+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
36+
# Validate input state
3537
settings = state[PostgresLifespanState.POSTGRES_SETTINGS]
3638

3739
if settings is None or not isinstance(settings, PostgresSettings):
@@ -48,6 +50,7 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[
4850

4951
yield {
5052
PostgresLifespanState.POSTGRES_ASYNC_ENGINE: async_engine,
53+
**called_state,
5154
}
5255

5356
finally:

packages/service-library/src/servicelib/fastapi/rabbitmq.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import warnings
23

34
from fastapi import FastAPI
45
from models_library.rabbitmq_messages import RabbitMessageBase
@@ -55,6 +56,13 @@ def setup_rabbit(
5556
settings -- Rabbit settings or if None, the connection to rabbit is not done upon startup
5657
name -- name for the rmq client name
5758
"""
59+
warnings.warn(
60+
"The 'setup_rabbit' function is deprecated and will be removed in a future release. "
61+
"Please use 'rabbitmq_lifespan' for managing RabbitMQ connections.",
62+
DeprecationWarning,
63+
stacklevel=2,
64+
)
65+
5866
app.state.rabbitmq_client = None # RabbitMQClient | None
5967
app.state.rabbitmq_client_name = name
6068
app.state.rabbitmq_settings = settings
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import logging
2+
from collections.abc import AsyncIterator
3+
4+
from fastapi import FastAPI
5+
from fastapi_lifespan_manager import State
6+
from pydantic import BaseModel, ValidationError
7+
from settings_library.rabbit import RabbitSettings
8+
9+
from ..rabbitmq import wait_till_rabbitmq_responsive
10+
from .lifespan_utils import (
11+
LifespanOnStartupError,
12+
lifespan_context,
13+
)
14+
15+
_logger = logging.getLogger(__name__)
16+
17+
18+
class RabbitMQConfigurationError(LifespanOnStartupError):
19+
msg_template = "Invalid RabbitMQ config on startup : {validation_error}"
20+
21+
22+
class RabbitMQLifespanState(BaseModel):
23+
RABBIT_SETTINGS: RabbitSettings
24+
25+
26+
async def rabbitmq_connectivity_lifespan(
27+
_: FastAPI, state: State
28+
) -> AsyncIterator[State]:
29+
"""Ensures RabbitMQ connectivity during lifespan.
30+
31+
For creating clients, use additional lifespans like rabbitmq_rpc_client_context.
32+
"""
33+
_lifespan_name = f"{__name__}.{rabbitmq_connectivity_lifespan.__name__}"
34+
35+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
36+
37+
# Validate input state
38+
try:
39+
rabbit_state = RabbitMQLifespanState.model_validate(state)
40+
rabbit_dsn_with_secrets = rabbit_state.RABBIT_SETTINGS.dsn
41+
except ValidationError as exc:
42+
raise RabbitMQConfigurationError(validation_error=exc, state=state) from exc
43+
44+
# Wait for RabbitMQ to be responsive
45+
await wait_till_rabbitmq_responsive(rabbit_dsn_with_secrets)
46+
47+
yield {"RABBIT_CONNECTIVITY_LIFESPAN_NAME": _lifespan_name, **called_state}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
import logging
3+
from collections.abc import AsyncIterator
4+
from typing import Annotated
5+
6+
from fastapi import FastAPI
7+
from fastapi_lifespan_manager import State
8+
from pydantic import BaseModel, StringConstraints, ValidationError
9+
from settings_library.redis import RedisDatabase, RedisSettings
10+
11+
from ..logging_utils import log_catch, log_context
12+
from ..redis import RedisClientSDK
13+
from .lifespan_utils import LifespanOnStartupError, lifespan_context
14+
15+
_logger = logging.getLogger(__name__)
16+
17+
18+
class RedisConfigurationError(LifespanOnStartupError):
19+
msg_template = "Invalid redis config on startup : {validation_error}"
20+
21+
22+
class RedisLifespanState(BaseModel):
23+
REDIS_SETTINGS: RedisSettings
24+
REDIS_CLIENT_NAME: Annotated[str, StringConstraints(min_length=3, max_length=32)]
25+
REDIS_CLIENT_DB: RedisDatabase
26+
27+
28+
async def redis_client_sdk_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
29+
_lifespan_name = f"{__name__}.{redis_client_sdk_lifespan.__name__}"
30+
31+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
32+
33+
# Validate input state
34+
try:
35+
redis_state = RedisLifespanState.model_validate(state)
36+
redis_dsn_with_secrets = redis_state.REDIS_SETTINGS.build_redis_dsn(
37+
redis_state.REDIS_CLIENT_DB
38+
)
39+
except ValidationError as exc:
40+
raise RedisConfigurationError(validation_error=exc, state=state) from exc
41+
42+
# Setup client
43+
with log_context(
44+
_logger,
45+
logging.INFO,
46+
f"Creating redis client with name={redis_state.REDIS_CLIENT_NAME}",
47+
):
48+
# NOTE: sdk integrats waiting until connection is ready
49+
# and will raise an exception if it cannot connect
50+
redis_client = RedisClientSDK(
51+
redis_dsn_with_secrets,
52+
client_name=redis_state.REDIS_CLIENT_NAME,
53+
)
54+
55+
try:
56+
yield {"REDIS_CLIENT_SDK": redis_client, **called_state}
57+
finally:
58+
# Teardown client
59+
with log_catch(_logger, reraise=False):
60+
await asyncio.wait_for(
61+
redis_client.shutdown(),
62+
# NOTE: shutdown already has a _HEALTHCHECK_TASK_TIMEOUT_S of 10s
63+
timeout=20,
64+
)

packages/service-library/src/servicelib/rabbitmq/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from models_library.rabbitmq_basic_types import RPCNamespace
22

33
from ._client import RabbitMQClient
4-
from ._client_rpc import RabbitMQRPCClient
4+
from ._client_rpc import RabbitMQRPCClient, rabbitmq_rpc_client_context
55
from ._constants import BIND_TO_ALL_TOPICS, RPC_REQUEST_DEFAULT_TIMEOUT_S
66
from ._errors import (
77
RemoteMethodNotRegisteredError,
@@ -28,6 +28,7 @@
2828
"RabbitMQRPCClient",
2929
"RemoteMethodNotRegisteredError",
3030
"is_rabbitmq_responsive",
31+
"rabbitmq_rpc_client_context",
3132
"wait_till_rabbitmq_responsive",
3233
)
3334

packages/service-library/src/servicelib/rabbitmq/_client_rpc.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
22
import functools
33
import logging
4-
from collections.abc import Callable
4+
from collections.abc import AsyncIterator, Callable
5+
from contextlib import asynccontextmanager
56
from dataclasses import dataclass
67
from typing import Any
78

@@ -156,3 +157,19 @@ async def unregister_handler(self, handler: Callable[..., Any]) -> None:
156157
raise RPCNotInitializedError
157158

158159
await self._rpc.unregister(handler)
160+
161+
162+
@asynccontextmanager
163+
async def rabbitmq_rpc_client_context(
164+
rpc_client_name: str, settings: RabbitSettings, **kwargs
165+
) -> AsyncIterator[RabbitMQRPCClient]:
166+
"""
167+
Adapter to create and close a RabbitMQRPCClient using an async context manager.
168+
"""
169+
rpc_client = await RabbitMQRPCClient.create(
170+
client_name=rpc_client_name, settings=settings, **kwargs
171+
)
172+
try:
173+
yield rpc_client
174+
finally:
175+
await rpc_client.close()

packages/service-library/tests/fastapi/test_lifespan_utils.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
from pytest_mock import MockerFixture
1717
from pytest_simcore.helpers.logging_tools import log_context
1818
from servicelib.fastapi.lifespan_utils import (
19+
LifespanAlreadyCalledError,
20+
LifespanExpectedCalledError,
1921
LifespanOnShutdownError,
2022
LifespanOnStartupError,
23+
ensure_lifespan_called,
24+
mark_lifespace_called,
2125
)
2226

2327

@@ -186,7 +190,7 @@ async def lifespan_failing_on_startup(app: FastAPI) -> AsyncIterator[State]:
186190
startup_step(_name)
187191
except RuntimeError as exc:
188192
handle_error(_name, exc)
189-
raise LifespanOnStartupError(module=_name) from exc
193+
raise LifespanOnStartupError(lifespan_name=_name) from exc
190194
yield {}
191195
shutdown_step(_name)
192196

@@ -201,7 +205,7 @@ async def lifespan_failing_on_shutdown(app: FastAPI) -> AsyncIterator[State]:
201205
shutdown_step(_name)
202206
except RuntimeError as exc:
203207
handle_error(_name, exc)
204-
raise LifespanOnShutdownError(module=_name) from exc
208+
raise LifespanOnShutdownError(lifespan_name=_name) from exc
205209

206210
return {
207211
"startup_step": startup_step,
@@ -228,7 +232,7 @@ async def test_app_lifespan_with_error_on_startup(
228232
assert not failing_lifespan_manager["startup_step"].called
229233
assert not failing_lifespan_manager["shutdown_step"].called
230234
assert exception.error_context() == {
231-
"module": "lifespan_failing_on_startup",
235+
"lifespan_name": "lifespan_failing_on_startup",
232236
"message": "Failed during startup of lifespan_failing_on_startup",
233237
"code": "RuntimeError.LifespanError.LifespanOnStartupError",
234238
}
@@ -250,7 +254,38 @@ async def test_app_lifespan_with_error_on_shutdown(
250254
assert failing_lifespan_manager["startup_step"].called
251255
assert not failing_lifespan_manager["shutdown_step"].called
252256
assert exception.error_context() == {
253-
"module": "lifespan_failing_on_shutdown",
257+
"lifespan_name": "lifespan_failing_on_shutdown",
254258
"message": "Failed during shutdown of lifespan_failing_on_shutdown",
255259
"code": "RuntimeError.LifespanError.LifespanOnShutdownError",
256260
}
261+
262+
263+
async def test_lifespan_called_more_than_once(is_pdb_enabled: bool):
264+
app_lifespan = LifespanManager()
265+
266+
@app_lifespan.add
267+
async def _one(_, state: State) -> AsyncIterator[State]:
268+
called_state = mark_lifespace_called(state, "test_lifespan_one")
269+
yield {"other": 0, **called_state}
270+
271+
@app_lifespan.add
272+
async def _two(_, state: State) -> AsyncIterator[State]:
273+
ensure_lifespan_called(state, "test_lifespan_one")
274+
275+
with pytest.raises(LifespanExpectedCalledError):
276+
ensure_lifespan_called(state, "test_lifespan_three")
277+
278+
called_state = mark_lifespace_called(state, "test_lifespan_two")
279+
yield {"something": 0, **called_state}
280+
281+
app_lifespan.add(_one) # added "by mistake"
282+
283+
with pytest.raises(LifespanAlreadyCalledError) as err_info:
284+
async with ASGILifespanManager(
285+
FastAPI(lifespan=app_lifespan),
286+
startup_timeout=None if is_pdb_enabled else 10,
287+
shutdown_timeout=None if is_pdb_enabled else 10,
288+
):
289+
...
290+
291+
assert err_info.value.lifespan_name == "test_lifespan_one"

0 commit comments

Comments
 (0)