Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for raw request injection in RequestContext. #380

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
self, context: RequestContext["ClientSession", Any, Any]
) -> types.ListRootsResult | types.ErrorData: ...


Expand Down Expand Up @@ -50,7 +50,7 @@ async def _default_message_handler(


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -60,7 +60,7 @@ async def _default_sampling_callback(


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientSession", Any, Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
Expand Down Expand Up @@ -331,7 +331,7 @@ async def send_roots_list_changed(self) -> None:
async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
ctx = RequestContext[ClientSession, Any, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
Expand Down
16 changes: 11 additions & 5 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
def lifespan_wrapper(
app: FastMCP,
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
) -> Callable[
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
]:
@asynccontextmanager
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
async with lifespan(app) as context:
yield context

Expand Down Expand Up @@ -491,6 +493,7 @@ async def handle_sse(request: Request) -> None:
streams[0],
streams[1],
self._mcp_server.create_initialization_options(),
request=request,
)

return Starlette(
Expand Down Expand Up @@ -592,13 +595,14 @@ def my_tool(x: int, ctx: Context) -> str:
The context is optional - tools that don't need it can omit the parameter.
"""

_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
_request_context: RequestContext[ServerSessionT, LifespanContextT, Request] | None
_fastmcp: FastMCP | None

def __init__(
self,
*,
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
request_context: RequestContext[ServerSessionT, LifespanContextT, Request]
| None = None,
fastmcp: FastMCP | None = None,
**kwargs: Any,
):
Expand All @@ -614,7 +618,9 @@ def fastmcp(self) -> FastMCP:
return self._fastmcp

@property
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
def request_context(
self,
) -> RequestContext[ServerSessionT, LifespanContextT, Request]:
"""Access to the underlying request context."""
if self._request_context is None:
raise ValueError("Context is not available outside of a request")
Expand Down
27 changes: 20 additions & 7 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def main():
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.context import RequestContext
from mcp.shared.context import RequestContext, RequestT
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder

Expand All @@ -91,7 +91,7 @@ async def main():
LifespanResultT = TypeVar("LifespanResultT")

# This will be properly typed in each Server instance's context
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
contextvars.ContextVar("request_ctx")
)

Expand All @@ -109,7 +109,7 @@ def __init__(


@asynccontextmanager
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
"""Default lifespan context manager that does nothing.

Args:
Expand All @@ -121,14 +121,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
yield {}


class Server(Generic[LifespanResultT]):
class Server(Generic[LifespanResultT, RequestT]):
def __init__(
self,
name: str,
version: str | None = None,
instructions: str | None = None,
lifespan: Callable[
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
[Server[LifespanResultT, RequestT]],
AbstractAsyncContextManager[LifespanResultT],
] = lifespan,
):
self.name = name
Expand Down Expand Up @@ -213,7 +214,9 @@ def get_capabilities(
)

@property
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
def request_context(
self,
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
"""If called outside of a request context, this will raise a LookupError."""
return request_ctx.get()

Expand Down Expand Up @@ -479,6 +482,7 @@ async def run(
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
request: RequestT | None = None,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -496,6 +500,7 @@ async def run(
session,
lifespan_context,
raise_exceptions,
request,
)

async def _handle_message(
Expand All @@ -506,6 +511,7 @@ async def _handle_message(
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool = False,
request: RequestT | None = None,
):
with warnings.catch_warnings(record=True) as w:
# TODO(Marcelo): We should be checking if message is Exception here.
Expand All @@ -515,7 +521,12 @@ async def _handle_message(
):
with responder:
await self._handle_request(
message, req, session, lifespan_context, raise_exceptions
message,
req,
session,
lifespan_context,
raise_exceptions,
request,
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)
Expand All @@ -530,6 +541,7 @@ async def _handle_request(
session: ServerSession,
lifespan_context: LifespanResultT,
raise_exceptions: bool,
request: RequestT | None,
):
logger.info(f"Processing request of type {type(req).__name__}")
if type(req) in self.request_handlers:
Expand All @@ -546,6 +558,7 @@ async def _handle_request(
message.request_meta,
session,
lifespan_context,
request=request,
)
)
response = await handler(req)
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/shared/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
RequestT = TypeVar("RequestT")


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT
request: RequestT | None = None
2 changes: 1 addition & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def create_client_server_memory_streams() -> (

@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server[Any],
server: Server[Any, Any],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/shared/progress.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic
from typing import Any, Generic

from pydantic import BaseModel

Expand Down Expand Up @@ -62,6 +62,7 @@ def progress(
ReceiveNotificationT,
],
LifespanContextT,
Any,
],
total: float | None = None,
) -> Generator[
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_list_roots_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_list_roots_callback():
)

async def list_roots_callback(
context: RequestContext[ClientSession, None],
context: RequestContext[ClientSession, None, None],
) -> ListRootsResult:
return callback_return

Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_sampling_callback():
)

async def sampling_callback(
context: RequestContext[ClientSession, None],
context: RequestContext[ClientSession, None, None],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return callback_return
Expand Down