diff --git a/README.md b/README.md index 370b4f33..bdbc9bca 100644 --- a/README.md +++ b/README.md @@ -476,9 +476,21 @@ server_params = StdioServerParameters( env=None # Optional environment variables ) +# Optional: create a sampling callback +async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent( + type="text", + text="Hello, world! from model", + ), + model="gpt-3.5-turbo", + stopReason="endTurn", + ) + async def run(): async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: + async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection await session.initialize() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4858ede5..37036e2b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,13 +1,51 @@ from datetime import timedelta +from typing import Any, Protocol from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic import AnyUrl, TypeAdapter import mcp.types as types -from mcp.shared.session import BaseSession +from mcp.shared.context import RequestContext +from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +class SamplingFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: ... + + +class ListRootsFnT(Protocol): + async def __call__( + self, context: RequestContext["ClientSession", Any] + ) -> types.ListRootsResult | types.ErrorData: ... + + +async def _default_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Sampling not supported", + ) + + +async def _default_list_roots_callback( + context: RequestContext["ClientSession", Any], +) -> types.ListRootsResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="List roots not supported", + ) + + +ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData) + + class ClientSession( BaseSession[ types.ClientRequest, @@ -22,6 +60,8 @@ def __init__( read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, ) -> None: super().__init__( read_stream, @@ -30,8 +70,24 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) + self._sampling_callback = sampling_callback or _default_sampling_callback + self._list_roots_callback = list_roots_callback or _default_list_roots_callback async def initialize(self) -> types.InitializeResult: + sampling = ( + types.SamplingCapability() if self._sampling_callback is not None else None + ) + roots = ( + types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, + ) + if self._list_roots_callback is not None + else None + ) + result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -39,14 +95,9 @@ async def initialize(self) -> types.InitializeResult: params=types.InitializeRequestParams( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( - sampling=None, + sampling=sampling, experimental=None, - roots=types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True - ), + roots=roots, ), clientInfo=types.Implementation(name="mcp", version="0.1.0"), ), @@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None: ) ) ) + + async def _received_request( + self, responder: RequestResponder[types.ServerRequest, types.ClientResult] + ) -> None: + ctx = RequestContext[ClientSession, Any]( + request_id=responder.request_id, + meta=responder.request_meta, + session=self, + lifespan_context=None, + ) + + match responder.request.root: + case types.CreateMessageRequest(params=params): + with responder: + response = await self._sampling_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + + case types.ListRootsRequest(): + with responder: + response = await self._list_roots_callback(ctx) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + + case types.PingRequest(): + with responder: + return await responder.respond( + types.ClientResult(root=types.EmptyResult()) + ) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 72549925..ae6b0be5 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> ( async def create_connected_server_and_client_session( server: Server, read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -80,6 +82,8 @@ async def create_connected_server_and_client_session( read_stream=client_read, write_stream=client_write, read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py new file mode 100644 index 00000000..384e7676 --- /dev/null +++ b/tests/client/test_list_roots_callback.py @@ -0,0 +1,70 @@ +import pytest +from pydantic import FileUrl + +from mcp.client.session import ClientSession +from mcp.server.fastmcp.server import Context +from mcp.shared.context import RequestContext +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + ListRootsResult, + Root, + TextContent, +) + + +@pytest.mark.anyio +async def test_list_roots_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = ListRootsResult( + roots=[ + Root( + uri=FileUrl("file://users/fake/test"), + name="Test Root 1", + ), + Root( + uri=FileUrl("file://users/fake/test/2"), + name="Test Root 2", + ), + ] + ) + + async def list_roots_callback( + context: RequestContext[ClientSession, None], + ) -> ListRootsResult: + return callback_return + + @server.tool("test_list_roots") + async def test_list_roots(context: Context, message: str): + roots = await context.session.list_roots() + assert roots == callback_return + return True + + # Test with list_roots callback + async with create_session( + server._mcp_server, list_roots_callback=list_roots_callback + ) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_list_roots", {"message": "test message"} + ) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Test without list_roots callback + async with create_session(server._mcp_server) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_list_roots", {"message": "test message"} + ) + assert result.isError is True + assert isinstance(result.content[0], TextContent) + assert ( + result.content[0].text + == "Error executing tool test_list_roots: List roots not supported" + ) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py new file mode 100644 index 00000000..ba586d4a --- /dev/null +++ b/tests/client/test_sampling_callback.py @@ -0,0 +1,73 @@ +import pytest + +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) + + +@pytest.mark.anyio +async def test_sampling_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = CreateMessageResult( + role="assistant", + content=TextContent( + type="text", text="This is a response from the sampling callback" + ), + model="test-model", + stopReason="endTurn", + ) + + async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + return callback_return + + @server.tool("test_sampling") + async def test_sampling_tool(message: str): + value = await server.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=message) + ) + ], + max_tokens=100, + ) + assert value == callback_return + return True + + # Test with sampling callback + async with create_session( + server._mcp_server, sampling_callback=sampling_callback + ) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + ) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Test without sampling callback + async with create_session(server._mcp_server) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + ) + assert result.isError is True + assert isinstance(result.content[0], TextContent) + assert ( + result.content[0].text + == "Error executing tool test_sampling: Sampling not supported" + ) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0bdec72d..95747ffd 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,17 @@ +import shutil + import pytest from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +tee: str = shutil.which("tee") # type: ignore + @pytest.mark.anyio +@pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_client(): - server_parameters = StdioServerParameters(command="/usr/bin/tee") + server_parameters = StdioServerParameters(command=tee) async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages