diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index ae3434be..1f5736e4 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1,5 +1,7 @@ """FastMCP - A more ergonomic interface for MCP servers.""" +from __future__ import annotations as _annotations + import inspect import json import re @@ -25,16 +27,10 @@ from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger from mcp.server.fastmcp.utilities.types import Image from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import ( - LifespanResultT, -) -from mcp.server.lowlevel.server import ( - Server as MCPServer, -) -from mcp.server.lowlevel.server import ( - lifespan as default_lifespan, -) -from mcp.server.session import ServerSession +from mcp.server.lowlevel.server import LifespanResultT +from mcp.server.lowlevel.server import Server as MCPServer +from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.shared.context import LifespanContextT, RequestContext @@ -45,21 +41,11 @@ ImageContent, TextContent, ) -from mcp.types import ( - Prompt as MCPPrompt, -) -from mcp.types import ( - PromptArgument as MCPPromptArgument, -) -from mcp.types import ( - Resource as MCPResource, -) -from mcp.types import ( - ResourceTemplate as MCPResourceTemplate, -) -from mcp.types import ( - Tool as MCPTool, -) +from mcp.types import Prompt as MCPPrompt +from mcp.types import PromptArgument as MCPPromptArgument +from mcp.types import Resource as MCPResource +from mcp.types import ResourceTemplate as MCPResourceTemplate +from mcp.types import Tool as MCPTool logger = get_logger(__name__) @@ -105,11 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( - app: "FastMCP", + app: FastMCP, lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer], AbstractAsyncContextManager[object]]: +) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: @asynccontextmanager - async def wrap(s: MCPServer) -> AsyncIterator[object]: + async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: async with lifespan(app) as context: yield context @@ -191,7 +177,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> "Context": + def get_context(self) -> "Context[ServerSession, object]": """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. @@ -564,7 +550,7 @@ def _convert_to_content( return [TextContent(type="text", text=result)] -class Context(BaseModel, Generic[LifespanContextT]): +class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]): """Context object providing access to MCP capabilities. This provides a cleaner interface to MCP's RequestContext functionality. @@ -598,13 +584,13 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSession, LifespanContextT] | None + _request_context: RequestContext[ServerSessionT, LifespanContextT] | None _fastmcp: FastMCP | None def __init__( self, *, - request_context: RequestContext[ServerSession, LifespanContextT] | None = None, + request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -620,7 +606,7 @@ def fastmcp(self) -> FastMCP: return self._fastmcp @property - def request_context(self) -> RequestContext[ServerSession, LifespanContextT]: + def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: """Access to the underlying request context.""" if self._request_context is None: raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index a8751a5f..da5d9348 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations as _annotations + import inspect from typing import TYPE_CHECKING, Any, Callable @@ -9,6 +11,8 @@ if TYPE_CHECKING: from mcp.server.fastmcp.server import Context + from mcp.server.session import ServerSessionT + from mcp.shared.context import LifespanContextT class Tool(BaseModel): @@ -68,7 +72,11 @@ def from_function( context_kwarg=context_kwarg, ) - async def run(self, arguments: dict, context: "Context | None" = None) -> Any: + async def run( + self, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT] | None = None, + ) -> Any: """Run the tool with arguments.""" try: return await self.fn_metadata.call_fn_with_arg_validation( diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 807c26b0..9a8bba8d 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -1,12 +1,16 @@ +from __future__ import annotations as _annotations + from collections.abc import Callable from typing import TYPE_CHECKING, Any from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.shared.context import LifespanContextT if TYPE_CHECKING: from mcp.server.fastmcp.server import Context + from mcp.server.session import ServerSessionT logger = get_logger(__name__) @@ -43,7 +47,10 @@ def add_tool( return tool async def call_tool( - self, name: str, arguments: dict, context: "Context | None" = None + self, + name: str, + arguments: dict[str, Any], + context: Context[ServerSessionT, LifespanContextT] | None = None, ) -> Any: """Call a tool by name with arguments.""" tool = self.get_tool(name) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 25e94365..817d1918 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -64,6 +64,8 @@ async def main(): messages from the client. """ +from __future__ import annotations as _annotations + import contextvars import logging import warnings @@ -107,7 +109,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: "Server") -> AsyncIterator[object]: +async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: """Default lifespan context manager that does nothing. Args: @@ -126,7 +128,7 @@ def __init__( version: str | None = None, instructions: str | None = None, lifespan: Callable[ - ["Server"], AbstractAsyncContextManager[LifespanResultT] + [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] ] = lifespan, ): self.name = name diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index d918b988..788bb9f8 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any +from typing import Any, TypeVar import anyio import anyio.lowlevel @@ -59,6 +59,9 @@ class InitializationState(Enum): Initialized = 3 +ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") + + class ServerSession( BaseSession[ types.ServerRequest,