diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 65d5e11e..155e7ecb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -14,14 +14,14 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext["ClientSession", Any, Any] ) -> types.ListRootsResult | types.ErrorData: ... @@ -50,7 +50,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -60,7 +60,7 @@ async def _default_sampling_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -331,7 +331,7 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder[types.ServerRequest, types.ClientResult] ) -> None: - ctx = RequestContext[ClientSession, Any]( + ctx = RequestContext[ClientSession, Any, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bf0ce880..22978eb2 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -98,9 +98,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: FastMCP, lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: +) -> Callable[ + [MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object] +]: @asynccontextmanager - async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: + async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async with lifespan(app) as context: yield context @@ -491,6 +493,7 @@ async def handle_sse(request: Request) -> None: streams[0], streams[1], self._mcp_server.create_initialization_options(), + request=request, ) return Starlette( @@ -592,13 +595,14 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSessionT, LifespanContextT] | None + _request_context: RequestContext[ServerSessionT, LifespanContextT, Request] | None _fastmcp: FastMCP | None def __init__( self, *, - request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None, + request_context: RequestContext[ServerSessionT, LifespanContextT, Request] + | None = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -614,7 +618,9 @@ def fastmcp(self) -> FastMCP: return self._fastmcp @property - def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: + def request_context( + self, + ) -> RequestContext[ServerSessionT, LifespanContextT, Request]: """Access to the underlying request context.""" if self._request_context is None: raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff305..3356cd99 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,7 +82,7 @@ async def main(): from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server -from mcp.shared.context import RequestContext +from mcp.shared.context import RequestContext, RequestT from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder @@ -91,7 +91,7 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT") # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( contextvars.ContextVar("request_ctx") ) @@ -109,7 +109,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: +async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: """Default lifespan context manager that does nothing. Args: @@ -121,14 +121,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: yield {} -class Server(Generic[LifespanResultT]): +class Server(Generic[LifespanResultT, RequestT]): def __init__( self, name: str, version: str | None = None, instructions: str | None = None, lifespan: Callable[ - [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] + [Server[LifespanResultT, RequestT]], + AbstractAsyncContextManager[LifespanResultT], ] = lifespan, ): self.name = name @@ -213,7 +214,9 @@ def get_capabilities( ) @property - def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: + def request_context( + self, + ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -479,6 +482,7 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + request: RequestT | None = None, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -496,6 +500,7 @@ async def run( session, lifespan_context, raise_exceptions, + request, ) async def _handle_message( @@ -506,6 +511,7 @@ async def _handle_message( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, + request: RequestT | None = None, ): with warnings.catch_warnings(record=True) as w: # TODO(Marcelo): We should be checking if message is Exception here. @@ -515,7 +521,12 @@ async def _handle_message( ): with responder: await self._handle_request( - message, req, session, lifespan_context, raise_exceptions + message, + req, + session, + lifespan_context, + raise_exceptions, + request, ) case types.ClientNotification(root=notify): await self._handle_notification(notify) @@ -530,6 +541,7 @@ async def _handle_request( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, + request: RequestT | None, ): logger.info(f"Processing request of type {type(req).__name__}") if type(req) in self.request_handlers: @@ -546,6 +558,7 @@ async def _handle_request( message.request_meta, session, lifespan_context, + request=request, ) ) response = await handler(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index ae85d3a1..2b673565 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -8,11 +8,13 @@ SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) LifespanContextT = TypeVar("LifespanContextT") +RequestT = TypeVar("RequestT") @dataclass -class RequestContext(Generic[SessionT, LifespanContextT]): +class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + request: RequestT | None = None diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 346f6156..5ebdebdf 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -59,7 +59,7 @@ async def create_client_server_memory_streams() -> ( @asynccontextmanager async def create_connected_server_and_client_session( - server: Server[Any], + server: Server[Any, Any], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 52e0017d..6e734869 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -1,7 +1,7 @@ from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Generic +from typing import Any, Generic from pydantic import BaseModel @@ -62,6 +62,7 @@ def progress( ReceiveNotificationT, ], LifespanContextT, + Any, ], total: float | None = None, ) -> Generator[ diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index f5b59821..6f3d3576 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -30,7 +30,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession, None, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index ba586d4a..ee09cc08 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -29,7 +29,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession, None, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return