Skip to content

Commit 7c47d1f

Browse files
authored
Merge pull request #203 from modelcontextprotocol/davidsp/clean-lifespan
feat: add lifespan context manager support
2 parents f10665d + 4d3e05f commit 7c47d1f

File tree

7 files changed

+372
-28
lines changed

7 files changed

+372
-28
lines changed

README.md

+60-1
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,41 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui
128128
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:
129129

130130
```python
131+
# Add lifespan support for startup/shutdown with strong typing
132+
from dataclasses import dataclass
133+
from typing import AsyncIterator
131134
from mcp.server.fastmcp import FastMCP
132135

133136
# Create a named server
134137
mcp = FastMCP("My App")
135138

136139
# Specify dependencies for deployment and development
137140
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])
141+
142+
@dataclass
143+
class AppContext:
144+
db: Database # Replace with your actual DB type
145+
146+
@asynccontextmanager
147+
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
148+
"""Manage application lifecycle with type-safe context"""
149+
try:
150+
# Initialize on startup
151+
await db.connect()
152+
yield AppContext(db=db)
153+
finally:
154+
# Cleanup on shutdown
155+
await db.disconnect()
156+
157+
# Pass lifespan to server
158+
mcp = FastMCP("My App", lifespan=app_lifespan)
159+
160+
# Access type-safe lifespan context in tools
161+
@mcp.tool()
162+
def query_db(ctx: Context) -> str:
163+
"""Tool that uses initialized resources"""
164+
db = ctx.request_context.lifespan_context["db"]
165+
return db.query()
138166
```
139167

140168
### Resources
@@ -334,7 +362,38 @@ def query_data(sql: str) -> str:
334362

335363
### Low-Level Server
336364

337-
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server:
365+
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API:
366+
367+
```python
368+
from contextlib import asynccontextmanager
369+
from typing import AsyncIterator
370+
371+
@asynccontextmanager
372+
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
373+
"""Manage server startup and shutdown lifecycle."""
374+
try:
375+
# Initialize resources on startup
376+
await db.connect()
377+
yield {"db": db}
378+
finally:
379+
# Clean up on shutdown
380+
await db.disconnect()
381+
382+
# Pass lifespan to server
383+
server = Server("example-server", lifespan=server_lifespan)
384+
385+
# Access lifespan context in handlers
386+
@server.call_tool()
387+
async def query_db(name: str, arguments: dict) -> list:
388+
ctx = server.request_context
389+
db = ctx.lifespan_context["db"]
390+
return await db.query(arguments["query"])
391+
```
392+
393+
The lifespan API provides:
394+
- A way to initialize resources when the server starts and clean them up when it stops
395+
- Access to initialized resources through the request context in handlers
396+
- Type-safe context passing between lifespan and request handlers
338397

339398
```python
340399
from mcp.server.lowlevel import Server, NotificationOptions

src/mcp/server/fastmcp/server.py

+40-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
import inspect
44
import json
55
import re
6+
from collections.abc import AsyncIterator
7+
from contextlib import (
8+
AbstractAsyncContextManager,
9+
asynccontextmanager,
10+
)
611
from itertools import chain
7-
from typing import Any, Callable, Literal, Sequence
12+
from typing import Any, Callable, Generic, Literal, Sequence
813

914
import anyio
1015
import pydantic_core
@@ -19,8 +24,16 @@
1924
from mcp.server.fastmcp.tools import ToolManager
2025
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
2126
from mcp.server.fastmcp.utilities.types import Image
22-
from mcp.server.lowlevel import Server as MCPServer
2327
from mcp.server.lowlevel.helper_types import ReadResourceContents
28+
from mcp.server.lowlevel.server import (
29+
LifespanResultT,
30+
)
31+
from mcp.server.lowlevel.server import (
32+
Server as MCPServer,
33+
)
34+
from mcp.server.lowlevel.server import (
35+
lifespan as default_lifespan,
36+
)
2437
from mcp.server.sse import SseServerTransport
2538
from mcp.server.stdio import stdio_server
2639
from mcp.shared.context import RequestContext
@@ -50,7 +63,7 @@
5063
logger = get_logger(__name__)
5164

5265

53-
class Settings(BaseSettings):
66+
class Settings(BaseSettings, Generic[LifespanResultT]):
5467
"""FastMCP server settings.
5568
5669
All settings can be configured via environment variables with the prefix FASTMCP_.
@@ -85,13 +98,36 @@ class Settings(BaseSettings):
8598
description="List of dependencies to install in the server environment",
8699
)
87100

101+
lifespan: (
102+
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
103+
) = Field(None, description="Lifespan context manager")
104+
105+
106+
def lifespan_wrapper(
107+
app: "FastMCP",
108+
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
109+
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
110+
@asynccontextmanager
111+
async def wrap(s: MCPServer) -> AsyncIterator[object]:
112+
async with lifespan(app) as context:
113+
yield context
114+
115+
return wrap
116+
88117

89118
class FastMCP:
90119
def __init__(
91120
self, name: str | None = None, instructions: str | None = None, **settings: Any
92121
):
93122
self.settings = Settings(**settings)
94-
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)
123+
124+
self._mcp_server = MCPServer(
125+
name=name or "FastMCP",
126+
instructions=instructions,
127+
lifespan=lifespan_wrapper(self, self.settings.lifespan)
128+
if self.settings.lifespan
129+
else default_lifespan,
130+
)
95131
self._tool_manager = ToolManager(
96132
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
97133
)

src/mcp/server/lowlevel/server.py

+47-10
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ async def main():
6868
import logging
6969
import warnings
7070
from collections.abc import Awaitable, Callable
71-
from typing import Any, Sequence
71+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
72+
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
7273

7374
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7475
from pydantic import AnyUrl
@@ -84,7 +85,10 @@ async def main():
8485

8586
logger = logging.getLogger(__name__)
8687

87-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
88+
LifespanResultT = TypeVar("LifespanResultT")
89+
90+
# This will be properly typed in each Server instance's context
91+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
8892
contextvars.ContextVar("request_ctx")
8993
)
9094

@@ -101,13 +105,33 @@ def __init__(
101105
self.tools_changed = tools_changed
102106

103107

104-
class Server:
108+
@asynccontextmanager
109+
async def lifespan(server: "Server") -> AsyncIterator[object]:
110+
"""Default lifespan context manager that does nothing.
111+
112+
Args:
113+
server: The server instance this lifespan is managing
114+
115+
Returns:
116+
An empty context object
117+
"""
118+
yield {}
119+
120+
121+
class Server(Generic[LifespanResultT]):
105122
def __init__(
106-
self, name: str, version: str | None = None, instructions: str | None = None
123+
self,
124+
name: str,
125+
version: str | None = None,
126+
instructions: str | None = None,
127+
lifespan: Callable[
128+
["Server"], AbstractAsyncContextManager[LifespanResultT]
129+
] = lifespan,
107130
):
108131
self.name = name
109132
self.version = version
110133
self.instructions = instructions
134+
self.lifespan = lifespan
111135
self.request_handlers: dict[
112136
type, Callable[..., Awaitable[types.ServerResult]]
113137
] = {
@@ -188,7 +212,7 @@ def get_capabilities(
188212
)
189213

190214
@property
191-
def request_context(self) -> RequestContext[ServerSession]:
215+
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
192216
"""If called outside of a request context, this will raise a LookupError."""
193217
return request_ctx.get()
194218

@@ -446,9 +470,14 @@ async def run(
446470
raise_exceptions: bool = False,
447471
):
448472
with warnings.catch_warnings(record=True) as w:
449-
async with ServerSession(
450-
read_stream, write_stream, initialization_options
451-
) as session:
473+
from contextlib import AsyncExitStack
474+
475+
async with AsyncExitStack() as stack:
476+
lifespan_context = await stack.enter_async_context(self.lifespan(self))
477+
session = await stack.enter_async_context(
478+
ServerSession(read_stream, write_stream, initialization_options)
479+
)
480+
452481
async for message in session.incoming_messages:
453482
logger.debug(f"Received message: {message}")
454483

@@ -460,21 +489,28 @@ async def run(
460489
):
461490
with responder:
462491
await self._handle_request(
463-
message, req, session, raise_exceptions
492+
message,
493+
req,
494+
session,
495+
lifespan_context,
496+
raise_exceptions,
464497
)
465498
case types.ClientNotification(root=notify):
466499
await self._handle_notification(notify)
467500

468501
for warning in w:
469502
logger.info(
470-
f"Warning: {warning.category.__name__}: {warning.message}"
503+
"Warning: %s: %s",
504+
warning.category.__name__,
505+
warning.message,
471506
)
472507

473508
async def _handle_request(
474509
self,
475510
message: RequestResponder,
476511
req: Any,
477512
session: ServerSession,
513+
lifespan_context: LifespanResultT,
478514
raise_exceptions: bool,
479515
):
480516
logger.info(f"Processing request of type {type(req).__name__}")
@@ -491,6 +527,7 @@ async def _handle_request(
491527
message.request_id,
492528
message.request_meta,
493529
session,
530+
lifespan_context,
494531
)
495532
)
496533
response = await handler(req)

src/mcp/shared/context.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from mcp.types import RequestId, RequestParams
66

77
SessionT = TypeVar("SessionT", bound=BaseSession)
8+
LifespanContextT = TypeVar("LifespanContextT")
89

910

1011
@dataclass
11-
class RequestContext(Generic[SessionT]):
12+
class RequestContext(Generic[SessionT, LifespanContextT]):
1213
request_id: RequestId
1314
meta: RequestParams.Meta | None
1415
session: SessionT
16+
lifespan_context: LifespanContextT

tests/issues/test_176_progress_token.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call():
2020
mock_meta.progressToken = 0 # This is the key test case - token is 0
2121

2222
request_context = RequestContext(
23-
request_id="test-request", session=mock_session, meta=mock_meta
23+
request_id="test-request",
24+
session=mock_session,
25+
meta=mock_meta,
26+
lifespan_context=None,
2427
)
2528

2629
# Create context with our mocks

tests/server/fastmcp/test_func_metadata.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ async def check_call(args):
236236

237237
def test_complex_function_json_schema():
238238
"""Test JSON schema generation for complex function arguments.
239-
239+
240240
Note: Different versions of pydantic output slightly different
241241
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:
242242
@@ -245,34 +245,34 @@ def test_complex_function_json_schema():
245245
"allOf": [{"$ref": "#/$defs/Model"}],
246246
"default": {}
247247
}
248-
248+
249249
2. Since 2.9.0:
250250
{
251251
"$ref": "#/$defs/Model",
252252
"default": {}
253253
}
254-
254+
255255
Both formats are valid and functionally equivalent. This test accepts either format
256256
to ensure compatibility across our supported pydantic versions.
257-
257+
258258
This change in format does not affect runtime behavior since:
259259
1. Both schemas validate the same way
260260
2. The actual model classes and validation logic are unchanged
261261
3. func_metadata uses model_validate/model_dump, not the schema directly
262262
"""
263263
meta = func_metadata(complex_arguments_fn)
264264
actual_schema = meta.arg_model.model_json_schema()
265-
265+
266266
# Create a copy of the actual schema to normalize
267267
normalized_schema = actual_schema.copy()
268-
268+
269269
# Normalize the my_model_a_with_default field to handle both pydantic formats
270-
if 'allOf' in actual_schema['properties']['my_model_a_with_default']:
271-
normalized_schema['properties']['my_model_a_with_default'] = {
272-
'$ref': '#/$defs/SomeInputModelA',
273-
'default': {}
270+
if "allOf" in actual_schema["properties"]["my_model_a_with_default"]:
271+
normalized_schema["properties"]["my_model_a_with_default"] = {
272+
"$ref": "#/$defs/SomeInputModelA",
273+
"default": {},
274274
}
275-
275+
276276
assert normalized_schema == {
277277
"$defs": {
278278
"InnerModel": {

0 commit comments

Comments
 (0)