Skip to content

Commit 5553f22

Browse files
author
Yassine Lassoued
committed
Added support for raw request injection in RequestContext.
1 parent 9a2bb6a commit 5553f22

17 files changed

+144
-24
lines changed

.idea/.gitignore

+8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/deployment.xml

+56
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/inspectionProfiles/profiles_settings.xml

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/modules.xml

+8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/python-sdk.iml

+18
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/vcs.xml

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/mcp/client/session.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
class SamplingFnT(Protocol):
1515
async def __call__(
1616
self,
17-
context: RequestContext["ClientSession", Any],
17+
context: RequestContext["ClientSession", Any, Any],
1818
params: types.CreateMessageRequestParams,
1919
) -> types.CreateMessageResult | types.ErrorData: ...
2020

2121

2222
class ListRootsFnT(Protocol):
2323
async def __call__(
24-
self, context: RequestContext["ClientSession", Any]
24+
self, context: RequestContext["ClientSession", Any, Any]
2525
) -> types.ListRootsResult | types.ErrorData: ...
2626

2727

@@ -50,7 +50,7 @@ async def _default_message_handler(
5050

5151

5252
async def _default_sampling_callback(
53-
context: RequestContext["ClientSession", Any],
53+
context: RequestContext["ClientSession", Any, Any],
5454
params: types.CreateMessageRequestParams,
5555
) -> types.CreateMessageResult | types.ErrorData:
5656
return types.ErrorData(
@@ -60,7 +60,7 @@ async def _default_sampling_callback(
6060

6161

6262
async def _default_list_roots_callback(
63-
context: RequestContext["ClientSession", Any],
63+
context: RequestContext["ClientSession", Any, Any],
6464
) -> types.ListRootsResult | types.ErrorData:
6565
return types.ErrorData(
6666
code=types.INVALID_REQUEST,
@@ -331,7 +331,7 @@ async def send_roots_list_changed(self) -> None:
331331
async def _received_request(
332332
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
333333
) -> None:
334-
ctx = RequestContext[ClientSession, Any](
334+
ctx = RequestContext[ClientSession, Any, Any](
335335
request_id=responder.request_id,
336336
meta=responder.request_meta,
337337
session=self,

src/mcp/server/fastmcp/server.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from mcp.server.session import ServerSession, ServerSessionT
3737
from mcp.server.sse import SseServerTransport
3838
from mcp.server.stdio import stdio_server
39-
from mcp.shared.context import LifespanContextT, RequestContext
39+
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
4040
from mcp.types import (
4141
AnyFunction,
4242
EmbeddedResource,
@@ -182,7 +182,7 @@ async def list_tools(self) -> list[MCPTool]:
182182
for info in tools
183183
]
184184

185-
def get_context(self) -> Context[ServerSession, object]:
185+
def get_context(self) -> Context[ServerSession, object, Request]:
186186
"""
187187
Returns a Context object. Note that the context will only be valid
188188
during a request; outside a request, most methods will error.
@@ -491,6 +491,7 @@ async def handle_sse(request: Request) -> None:
491491
streams[0],
492492
streams[1],
493493
self._mcp_server.create_initialization_options(),
494+
request=request,
494495
)
495496

496497
return Starlette(
@@ -558,7 +559,7 @@ def _convert_to_content(
558559
return [TextContent(type="text", text=result)]
559560

560561

561-
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
562+
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
562563
"""Context object providing access to MCP capabilities.
563564
564565
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -592,13 +593,13 @@ def my_tool(x: int, ctx: Context) -> str:
592593
The context is optional - tools that don't need it can omit the parameter.
593594
"""
594595

595-
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
596+
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
596597
_fastmcp: FastMCP | None
597598

598599
def __init__(
599600
self,
600601
*,
601-
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
602+
request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None = None,
602603
fastmcp: FastMCP | None = None,
603604
**kwargs: Any,
604605
):

src/mcp/server/fastmcp/tools/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def from_function(
7676
async def run(
7777
self,
7878
arguments: dict[str, Any],
79-
context: Context[ServerSessionT, LifespanContextT] | None = None,
79+
context: Context[ServerSessionT, LifespanContextT, Any] | None = None,
8080
) -> Any:
8181
"""Run the tool with arguments."""
8282
try:

src/mcp/server/fastmcp/tools/tool_manager.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ async def call_tool(
5050
self,
5151
name: str,
5252
arguments: dict[str, Any],
53-
context: Context[ServerSessionT, LifespanContextT] | None = None,
53+
context: Context[ServerSessionT, LifespanContextT, Any] | None = None,
5454
) -> Any:
5555
"""Call a tool by name with arguments."""
5656
tool = self.get_tool(name)

src/mcp/server/lowlevel/server.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def main():
8282
from mcp.server.models import InitializationOptions
8383
from mcp.server.session import ServerSession
8484
from mcp.server.stdio import stdio_server as stdio_server
85-
from mcp.shared.context import RequestContext
85+
from mcp.shared.context import RequestContext, RequestT
8686
from mcp.shared.exceptions import McpError
8787
from mcp.shared.session import RequestResponder
8888

@@ -91,7 +91,7 @@ async def main():
9191
LifespanResultT = TypeVar("LifespanResultT")
9292

9393
# This will be properly typed in each Server instance's context
94-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
94+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, RequestT]] = (
9595
contextvars.ContextVar("request_ctx")
9696
)
9797

@@ -109,7 +109,7 @@ def __init__(
109109

110110

111111
@asynccontextmanager
112-
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
112+
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
113113
"""Default lifespan context manager that does nothing.
114114
115115
Args:
@@ -121,14 +121,14 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
121121
yield {}
122122

123123

124-
class Server(Generic[LifespanResultT]):
124+
class Server(Generic[LifespanResultT, RequestT]):
125125
def __init__(
126126
self,
127127
name: str,
128128
version: str | None = None,
129129
instructions: str | None = None,
130130
lifespan: Callable[
131-
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
131+
[Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT]
132132
] = lifespan,
133133
):
134134
self.name = name
@@ -213,7 +213,7 @@ def get_capabilities(
213213
)
214214

215215
@property
216-
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
216+
def request_context(self) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
217217
"""If called outside of a request context, this will raise a LookupError."""
218218
return request_ctx.get()
219219

@@ -479,6 +479,7 @@ async def run(
479479
# but also make tracing exceptions much easier during testing and when using
480480
# in-process servers.
481481
raise_exceptions: bool = False,
482+
request: RequestT | None = None,
482483
):
483484
async with AsyncExitStack() as stack:
484485
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -496,6 +497,7 @@ async def run(
496497
session,
497498
lifespan_context,
498499
raise_exceptions,
500+
request,
499501
)
500502

501503
async def _handle_message(
@@ -506,6 +508,7 @@ async def _handle_message(
506508
session: ServerSession,
507509
lifespan_context: LifespanResultT,
508510
raise_exceptions: bool = False,
511+
request: RequestT | None = None,
509512
):
510513
with warnings.catch_warnings(record=True) as w:
511514
# TODO(Marcelo): We should be checking if message is Exception here.
@@ -515,7 +518,7 @@ async def _handle_message(
515518
):
516519
with responder:
517520
await self._handle_request(
518-
message, req, session, lifespan_context, raise_exceptions
521+
message, req, session, lifespan_context, raise_exceptions, request
519522
)
520523
case types.ClientNotification(root=notify):
521524
await self._handle_notification(notify)
@@ -530,6 +533,7 @@ async def _handle_request(
530533
session: ServerSession,
531534
lifespan_context: LifespanResultT,
532535
raise_exceptions: bool,
536+
request: RequestT | None,
533537
):
534538
logger.info(f"Processing request of type {type(req).__name__}")
535539
if type(req) in self.request_handlers:
@@ -546,6 +550,7 @@ async def _handle_request(
546550
message.request_meta,
547551
session,
548552
lifespan_context,
553+
request=request,
549554
)
550555
)
551556
response = await handler(req)

src/mcp/shared/context.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88

99
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
1010
LifespanContextT = TypeVar("LifespanContextT")
11+
RequestT = TypeVar("RequestT")
1112

1213

1314
@dataclass
14-
class RequestContext(Generic[SessionT, LifespanContextT]):
15+
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1516
request_id: RequestId
1617
meta: RequestParams.Meta | None
1718
session: SessionT
1819
lifespan_context: LifespanContextT
20+
request: RequestT | None = None

src/mcp/shared/memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def create_client_server_memory_streams() -> (
5959

6060
@asynccontextmanager
6161
async def create_connected_server_and_client_session(
62-
server: Server[Any],
62+
server: Server[Any, Any],
6363
read_timeout_seconds: timedelta | None = None,
6464
sampling_callback: SamplingFnT | None = None,
6565
list_roots_callback: ListRootsFnT | None = None,

src/mcp/shared/progress.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Generator
22
from contextlib import contextmanager
33
from dataclasses import dataclass, field
4-
from typing import Generic
4+
from typing import Generic, Any
55

66
from pydantic import BaseModel
77

@@ -62,6 +62,7 @@ def progress(
6262
ReceiveNotificationT,
6363
],
6464
LifespanContextT,
65+
Any,
6566
],
6667
total: float | None = None,
6768
) -> Generator[

tests/client/test_list_roots_callback.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from types import NoneType
2+
13
import pytest
24
from pydantic import FileUrl
35

@@ -30,7 +32,7 @@ async def test_list_roots_callback():
3032
)
3133

3234
async def list_roots_callback(
33-
context: RequestContext[ClientSession, None],
35+
context: RequestContext[ClientSession, None, None],
3436
) -> ListRootsResult:
3537
return callback_return
3638

0 commit comments

Comments
 (0)