diff --git a/README.md b/README.md index 310bb35b..370b4f33 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,9 @@ The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you bui The FastMCP server is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: ```python +# Add lifespan support for startup/shutdown with strong typing +from dataclasses import dataclass +from typing import AsyncIterator from mcp.server.fastmcp import FastMCP # Create a named server @@ -135,6 +138,31 @@ mcp = FastMCP("My App") # Specify dependencies for deployment and development mcp = FastMCP("My App", dependencies=["pandas", "numpy"]) + +@dataclass +class AppContext: + db: Database # Replace with your actual DB type + +@asynccontextmanager +async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: + """Manage application lifecycle with type-safe context""" + try: + # Initialize on startup + await db.connect() + yield AppContext(db=db) + finally: + # Cleanup on shutdown + await db.disconnect() + +# Pass lifespan to server +mcp = FastMCP("My App", lifespan=app_lifespan) + +# Access type-safe lifespan context in tools +@mcp.tool() +def query_db(ctx: Context) -> str: + """Tool that uses initialized resources""" + db = ctx.request_context.lifespan_context["db"] + return db.query() ``` ### Resources @@ -334,7 +362,38 @@ def query_data(sql: str) -> str: ### Low-Level Server -For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server: +For more control, you can use the low-level server implementation directly. This gives you full access to the protocol and allows you to customize every aspect of your server, including lifecycle management through the lifespan API: + +```python +from contextlib import asynccontextmanager +from typing import AsyncIterator + +@asynccontextmanager +async def server_lifespan(server: Server) -> AsyncIterator[dict]: + """Manage server startup and shutdown lifecycle.""" + try: + # Initialize resources on startup + await db.connect() + yield {"db": db} + finally: + # Clean up on shutdown + await db.disconnect() + +# Pass lifespan to server +server = Server("example-server", lifespan=server_lifespan) + +# Access lifespan context in handlers +@server.call_tool() +async def query_db(name: str, arguments: dict) -> list: + ctx = server.request_context + db = ctx.lifespan_context["db"] + return await db.query(arguments["query"]) +``` + +The lifespan API provides: +- A way to initialize resources when the server starts and clean them up when it stops +- Access to initialized resources through the request context in handlers +- Type-safe context passing between lifespan and request handlers ```python from mcp.server.lowlevel import Server, NotificationOptions diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index aa7c79bc..5ae30a5c 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -3,8 +3,13 @@ import inspect import json import re +from collections.abc import AsyncIterator +from contextlib import ( + AbstractAsyncContextManager, + asynccontextmanager, +) from itertools import chain -from typing import Any, Callable, Literal, Sequence +from typing import Any, Callable, Generic, Literal, Sequence import anyio import pydantic_core @@ -19,8 +24,16 @@ from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image -from mcp.server.lowlevel import Server as MCPServer from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.lowlevel.server import ( + LifespanResultT, +) +from mcp.server.lowlevel.server import ( + Server as MCPServer, +) +from mcp.server.lowlevel.server import ( + lifespan as default_lifespan, +) from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.shared.context import RequestContext @@ -50,7 +63,7 @@ logger = get_logger(__name__) -class Settings(BaseSettings): +class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. All settings can be configured via environment variables with the prefix FASTMCP_. @@ -85,13 +98,36 @@ class Settings(BaseSettings): description="List of dependencies to install in the server environment", ) + lifespan: ( + Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None + ) = Field(None, description="Lifespan context manager") + + +def lifespan_wrapper( + app: "FastMCP", + lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], +) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]: + @asynccontextmanager + async def wrap(s: MCPServer) -> AsyncIterator[object]: + async with lifespan(app) as context: + yield context + + return wrap + class FastMCP: def __init__( self, name: str | None = None, instructions: str | None = None, **settings: Any ): self.settings = Settings(**settings) - self._mcp_server = MCPServer(name=name or "FastMCP", instructions=instructions) + + self._mcp_server = MCPServer( + name=name or "FastMCP", + instructions=instructions, + lifespan=lifespan_wrapper(self, self.settings.lifespan) + if self.settings.lifespan + else default_lifespan, + ) self._tool_manager = ToolManager( warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools ) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3d917226..643e1a27 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,7 +68,8 @@ async def main(): import logging import warnings from collections.abc import Awaitable, Callable -from typing import Any, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Generic, Sequence, TypeVar from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -84,7 +85,10 @@ async def main(): logger = logging.getLogger(__name__) -request_ctx: contextvars.ContextVar[RequestContext[ServerSession]] = ( +LifespanResultT = TypeVar("LifespanResultT") + +# This will be properly typed in each Server instance's context +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( contextvars.ContextVar("request_ctx") ) @@ -101,13 +105,33 @@ def __init__( self.tools_changed = tools_changed -class Server: +@asynccontextmanager +async def lifespan(server: "Server") -> AsyncIterator[object]: + """Default lifespan context manager that does nothing. + + Args: + server: The server instance this lifespan is managing + + Returns: + An empty context object + """ + yield {} + + +class Server(Generic[LifespanResultT]): def __init__( - self, name: str, version: str | None = None, instructions: str | None = None + self, + name: str, + version: str | None = None, + instructions: str | None = None, + lifespan: Callable[ + ["Server"], AbstractAsyncContextManager[LifespanResultT] + ] = lifespan, ): self.name = name self.version = version self.instructions = instructions + self.lifespan = lifespan self.request_handlers: dict[ type, Callable[..., Awaitable[types.ServerResult]] ] = { @@ -188,7 +212,7 @@ def get_capabilities( ) @property - def request_context(self) -> RequestContext[ServerSession]: + def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -446,9 +470,14 @@ async def run( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - async with ServerSession( - read_stream, write_stream, initialization_options - ) as session: + from contextlib import AsyncExitStack + + async with AsyncExitStack() as stack: + lifespan_context = await stack.enter_async_context(self.lifespan(self)) + session = await stack.enter_async_context( + ServerSession(read_stream, write_stream, initialization_options) + ) + async for message in session.incoming_messages: logger.debug(f"Received message: {message}") @@ -460,14 +489,20 @@ async def run( ): with responder: await self._handle_request( - message, req, session, raise_exceptions + message, + req, + session, + lifespan_context, + raise_exceptions, ) case types.ClientNotification(root=notify): await self._handle_notification(notify) for warning in w: logger.info( - f"Warning: {warning.category.__name__}: {warning.message}" + "Warning: %s: %s", + warning.category.__name__, + warning.message, ) async def _handle_request( @@ -475,6 +510,7 @@ async def _handle_request( message: RequestResponder, req: Any, session: ServerSession, + lifespan_context: LifespanResultT, raise_exceptions: bool, ): logger.info(f"Processing request of type {type(req).__name__}") @@ -491,6 +527,7 @@ async def _handle_request( message.request_id, message.request_meta, session, + lifespan_context, ) ) response = await handler(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index 760d5587..a45fdacd 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -5,10 +5,12 @@ from mcp.types import RequestId, RequestParams SessionT = TypeVar("SessionT", bound=BaseSession) +LifespanContextT = TypeVar("LifespanContextT") @dataclass -class RequestContext(Generic[SessionT]): +class RequestContext(Generic[SessionT, LifespanContextT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT + lifespan_context: LifespanContextT diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index ed8ab128..7f9131a1 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -20,7 +20,10 @@ async def test_progress_token_zero_first_call(): mock_meta.progressToken = 0 # This is the key test case - token is 0 request_context = RequestContext( - request_id="test-request", session=mock_session, meta=mock_meta + request_id="test-request", + session=mock_session, + meta=mock_meta, + lifespan_context=None, ) # Create context with our mocks diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index b68fb902..6461648e 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -236,7 +236,7 @@ async def check_call(args): def test_complex_function_json_schema(): """Test JSON schema generation for complex function arguments. - + Note: Different versions of pydantic output slightly different JSON Schema formats for model fields with defaults. The format changed in 2.9.0: @@ -245,16 +245,16 @@ def test_complex_function_json_schema(): "allOf": [{"$ref": "#/$defs/Model"}], "default": {} } - + 2. Since 2.9.0: { "$ref": "#/$defs/Model", "default": {} } - + Both formats are valid and functionally equivalent. This test accepts either format to ensure compatibility across our supported pydantic versions. - + This change in format does not affect runtime behavior since: 1. Both schemas validate the same way 2. The actual model classes and validation logic are unchanged @@ -262,17 +262,17 @@ def test_complex_function_json_schema(): """ meta = func_metadata(complex_arguments_fn) actual_schema = meta.arg_model.model_json_schema() - + # Create a copy of the actual schema to normalize normalized_schema = actual_schema.copy() - + # Normalize the my_model_a_with_default field to handle both pydantic formats - if 'allOf' in actual_schema['properties']['my_model_a_with_default']: - normalized_schema['properties']['my_model_a_with_default'] = { - '$ref': '#/$defs/SomeInputModelA', - 'default': {} + if "allOf" in actual_schema["properties"]["my_model_a_with_default"]: + normalized_schema["properties"]["my_model_a_with_default"] = { + "$ref": "#/$defs/SomeInputModelA", + "default": {}, } - + assert normalized_schema == { "$defs": { "InnerModel": { diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py new file mode 100644 index 00000000..14afb6b0 --- /dev/null +++ b/tests/server/test_lifespan.py @@ -0,0 +1,207 @@ +"""Tests for lifespan functionality in both low-level and FastMCP servers.""" + +from contextlib import asynccontextmanager +from typing import AsyncIterator + +import anyio +import pytest +from pydantic import TypeAdapter + +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.lowlevel.server import NotificationOptions, Server +from mcp.server.models import InitializationOptions +from mcp.types import ( + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, +) + + +@pytest.mark.anyio +async def test_lowlevel_server_lifespan(): + """Test that lifespan works in low-level server.""" + + @asynccontextmanager + async def test_lifespan(server: Server) -> AsyncIterator[dict]: + """Test lifespan context that tracks startup/shutdown.""" + context = {"started": False, "shutdown": False} + try: + context["started"] = True + yield context + finally: + context["shutdown"] = True + + server = Server("test", lifespan=test_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + + # Create a tool that accesses lifespan context + @server.call_tool() + async def check_lifespan(name: str, arguments: dict) -> list: + ctx = server.request_context + assert isinstance(ctx.lifespan_context, dict) + assert ctx.lifespan_context["started"] + assert not ctx.lifespan_context["shutdown"] + return [{"type": "text", "text": "true"}] + + # Run server in background task + async with anyio.create_task_group() as tg: + + async def run_server(): + await server.run( + receive_stream1, + send_stream2, + InitializationOptions( + server_name="test", + server_version="0.1.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + raise_exceptions=True, + ) + + tg.start_soon(run_server) + + # Initialize the server + params = InitializeRequestParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(), + clientInfo=Implementation(name="test-client", version="0.1.0"), + ) + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ) + ) + response = await receive_stream2.receive() + + # Send initialized notification + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + + # Call the tool to verify lifespan context + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ) + ) + + # Get response and verify + response = await receive_stream2.receive() + assert response.root.result["content"][0]["text"] == "true" + + # Cancel server task + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_fastmcp_server_lifespan(): + """Test that lifespan works in FastMCP server.""" + + @asynccontextmanager + async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]: + """Test lifespan context that tracks startup/shutdown.""" + context = {"started": False, "shutdown": False} + try: + context["started"] = True + yield context + finally: + context["shutdown"] = True + + server = FastMCP("test", lifespan=test_lifespan) + + # Create memory streams for testing + send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + + # Add a tool that checks lifespan context + @server.tool() + def check_lifespan(ctx: Context) -> bool: + """Tool that checks lifespan context.""" + assert isinstance(ctx.request_context.lifespan_context, dict) + assert ctx.request_context.lifespan_context["started"] + assert not ctx.request_context.lifespan_context["shutdown"] + return True + + # Run server in background task + async with anyio.create_task_group() as tg: + + async def run_server(): + await server._mcp_server.run( + receive_stream1, + send_stream2, + server._mcp_server.create_initialization_options(), + raise_exceptions=True, + ) + + tg.start_soon(run_server) + + # Initialize the server + params = InitializeRequestParams( + protocolVersion="2024-11-05", + capabilities=ClientCapabilities(), + clientInfo=Implementation(name="test-client", version="0.1.0"), + ) + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) + ) + ) + response = await receive_stream2.receive() + + # Send initialized notification + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + + # Call the tool to verify lifespan context + await send_stream1.send( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) + ) + ) + + # Get response and verify + response = await receive_stream2.receive() + assert response.root.result["content"][0]["text"] == "true" + + # Cancel server task + tg.cancel_scope.cancel()