Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add lifespan context manager support #203

Merged
merged 6 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,41 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui
The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing:

```python
# Add lifespan support for startup/shutdown with strong typing
from dataclasses import dataclass
from typing import AsyncIterator
from mcp.server.fastmcp import FastMCP

# Create a named server
mcp = FastMCP("My App")

# Specify dependencies for deployment and development
mcp = FastMCP("My App", dependencies=["pandas", "numpy"])

@dataclass
class AppContext:
db: Database # Replace with your actual DB type

@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
"""Manage application lifecycle with type-safe context"""
try:
# Initialize on startup
await db.connect()
yield AppContext(db=db)
finally:
# Cleanup on shutdown
await db.disconnect()

# Pass lifespan to server
mcp = FastMCP("My App", lifespan=app_lifespan)

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context) -> str:
"""Tool that uses initialized resources"""
db = ctx.request_context.lifespan_context["db"]
return db.query()
```

### Resources
Expand Down Expand Up @@ -334,7 +362,38 @@ def query_data(sql: str) -> str:

### Low-Level Server

For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server:
For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API:

```python
from contextlib import asynccontextmanager
from typing import AsyncIterator

@asynccontextmanager
async def server_lifespan(server: Server) -> AsyncIterator[dict]:
"""Manage server startup and shutdown lifecycle."""
try:
# Initialize resources on startup
await db.connect()
yield {"db": db}
finally:
# Clean up on shutdown
await db.disconnect()

# Pass lifespan to server
server = Server("example-server", lifespan=server_lifespan)

# Access lifespan context in handlers
@server.call_tool()
async def query_db(name: str, arguments: dict) -> list:
ctx = server.request_context
db = ctx.lifespan_context["db"]
return await db.query(arguments["query"])
```

The lifespan API provides:
- A way to initialize resources when the server starts and clean them up when it stops
- Access to initialized resources through the request context in handlers
- Type-safe context passing between lifespan and request handlers

```python
from mcp.server.lowlevel import Server, NotificationOptions
Expand Down
44 changes: 40 additions & 4 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import inspect
import json
import re
from collections.abc import AsyncIterator
from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
)
from itertools import chain
from typing import Any, Callable, Literal, Sequence
from typing import Any, Callable, Generic, Literal, Sequence

import anyio
import pydantic_core
Expand All @@ -19,8 +24,16 @@
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.server.lowlevel import Server as MCPServer
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import (
LifespanResultT,
)
from mcp.server.lowlevel.server import (
Server as MCPServer,
)
from mcp.server.lowlevel.server import (
lifespan as default_lifespan,
)
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
from mcp.shared.context import RequestContext
Expand Down Expand Up @@ -50,7 +63,7 @@
logger = get_logger(__name__)


class Settings(BaseSettings):
class Settings(BaseSettings, Generic[LifespanResultT]):
"""FastMCP server settings.

All settings can be configured via environment variables with the prefix FASTMCP_.
Expand Down Expand Up @@ -85,13 +98,36 @@ class Settings(BaseSettings):
description="List of dependencies to install in the server environment",
)

lifespan: (
Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None
) = Field(None, description="Lifespan context manager")


def lifespan_wrapper(
app: "FastMCP",
lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]:
@asynccontextmanager
async def wrap(s: MCPServer) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context

return wrap


class FastMCP:
def __init__(
self, name: str | None = None, instructions: str | None = None, **settings: Any
):
self.settings = Settings(**settings)
self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions)

self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
lifespan=lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan,
)
self._tool_manager = ToolManager(
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
)
Expand Down
57 changes: 47 additions & 10 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ async def main():
import logging
import warnings
from collections.abc import Awaitable, Callable
from typing import Any, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
Expand All @@ -84,7 +85,10 @@ async def main():

logger = logging.getLogger(__name__)

request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = (
LifespanResultT = TypeVar("LifespanResultT")

# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
contextvars.ContextVar("request_ctx")
)

Expand All @@ -101,13 +105,33 @@ def __init__(
self.tools_changed = tools_changed


class Server:
@asynccontextmanager
async def lifespan(server: "Server") -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.

Args:
server: The server instance this lifespan is managing

Returns:
An empty context object
"""
yield {}


class Server(Generic[LifespanResultT]):
def __init__(
self, name: str, version: str | None = None, instructions: str | None = None
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
["Server"], AbstractAsyncContextManager[LifespanResultT]
] = lifespan,
):
self.name = name
self.version = version
self.instructions = instructions
self.lifespan = lifespan
self.request_handlers: dict[
type, Callable[..., Awaitable[types.ServerResult]]
] = {
Expand Down Expand Up @@ -188,7 +212,7 @@ def get_capabilities(
)

@property
def request_context(self) -> RequestContext[ServerSession]:
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()

Expand Down Expand Up @@ -446,9 +470,14 @@ async def run(
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
from contextlib import AsyncExitStack

async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
session = await stack.enter_async_context(
ServerSession(read_stream, write_stream, initialization_options)
)
Comment on lines +475 to +479
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this different from doing async with lifespan_context, session?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AsyncExitStack() version provides more granular control over multiple context managers. If an error occurs during the second enter_async_context, the first one will still be properly cleaned up.

Chatted more with Claude about it, we came to the. conclusion that AsyncExitStack is better here.


async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand All @@ -460,21 +489,28 @@ async def run(
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
message,
req,
session,
lifespan_context,
raise_exceptions,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

for warning in w:
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
"Warning: %s: %s",
warning.category.__name__,
warning.message,
)

async def _handle_request(
self,
message: RequestResponder,
req: Any,
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
):
logger.info(f"Processing request of type {type(req).__name__}")
Expand All @@ -491,6 +527,7 @@ async def _handle_request(
message.request_id,
message.request_meta,
session,
lifespan_context,
)
)
response = await handler(req)
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from mcp.types import RequestId, RequestParams

SessionT = TypeVar("SessionT", bound=BaseSession)
LifespanContextT = TypeVar("LifespanContextT")


@dataclass
class RequestContext(Generic[SessionT]):
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
5 changes: 4 additions & 1 deletion tests/issues/test_176_progress_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call():
mock_meta.progressToken = 0 # This is the key test case - token is 0

request_context = RequestContext(
request_id="test-request", session=mock_session, meta=mock_meta
request_id="test-request",
session=mock_session,
meta=mock_meta,
lifespan_context=None,
)

# Create context with our mocks
Expand Down
22 changes: 11 additions & 11 deletions tests/server/fastmcp/test_func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def check_call(args):

def test_complex_function_json_schema():
"""Test JSON schema generation for complex function arguments.

Note: Different versions of pydantic output slightly different
JSON Schema formats for model fields with defaults. The format changed in 2.9.0:

Expand All @@ -245,34 +245,34 @@ def test_complex_function_json_schema():
"allOf": [{"$ref": "#/$defs/Model"}],
"default": {}
}

2. Since 2.9.0:
{
"$ref": "#/$defs/Model",
"default": {}
}

Both formats are valid and functionally equivalent. This test accepts either format
to ensure compatibility across our supported pydantic versions.

This change in format does not affect runtime behavior since:
1. Both schemas validate the same way
2. The actual model classes and validation logic are unchanged
3. func_metadata uses model_validate/model_dump, not the schema directly
"""
meta = func_metadata(complex_arguments_fn)
actual_schema = meta.arg_model.model_json_schema()

# Create a copy of the actual schema to normalize
normalized_schema = actual_schema.copy()

# Normalize the my_model_a_with_default field to handle both pydantic formats
if 'allOf' in actual_schema['properties']['my_model_a_with_default']:
normalized_schema['properties']['my_model_a_with_default'] = {
'$ref': '#/$defs/SomeInputModelA',
'default': {}
if "allOf" in actual_schema["properties"]["my_model_a_with_default"]:
normalized_schema["properties"]["my_model_a_with_default"] = {
"$ref": "#/$defs/SomeInputModelA",
"default": {},
}

assert normalized_schema == {
"$defs": {
"InnerModel": {
Expand Down
Loading