diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 122acebb..ae3434be 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -37,7 +37,7 @@ from mcp.server.session import ServerSession from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server -from mcp.shared.context import RequestContext +from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, EmbeddedResource, @@ -564,7 +564,7 @@ def _convert_to_content( return [TextContent(type="text", text=result)] -class Context(BaseModel): +class Context(BaseModel, Generic[LifespanContextT]): """Context object providing access to MCP capabilities. This provides a cleaner interface to MCP's RequestContext functionality. @@ -598,13 +598,13 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSession, Any] | None + _request_context: RequestContext[ServerSession, LifespanContextT] | None _fastmcp: FastMCP | None def __init__( self, *, - request_context: RequestContext | None = None, + request_context: RequestContext[ServerSession, LifespanContextT] | None = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -620,7 +620,7 @@ def fastmcp(self) -> FastMCP: return self._fastmcp @property - def request_context(self) -> RequestContext: + def request_context(self) -> RequestContext[ServerSession, LifespanContextT]: """Access to the underlying request context.""" if self._request_context is None: raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index a45fdacd..63759ca4 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,11 +1,13 @@ from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic + +from typing_extensions import TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession) -LifespanContextT = TypeVar("LifespanContextT") +SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +LifespanContextT = TypeVar("LifespanContextT", default=None) @dataclass