diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..3b7fc3fae 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -116,12 +116,18 @@ def __init__( self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() - roots = types.RootsCapability( + sampling = ( + types.SamplingCapability() + if self._sampling_callback is not _default_sampling_callback + else None + ) + roots = ( # TODO: Should this be based on whether we # _will_ send notifications, or only whether # they're supported? - listChanged=True, + types.RootsCapability(listChanged=True) + if self._list_roots_callback is not _default_list_roots_callback + else None ) result = await self.send_request( diff --git a/src/mcp/types.py b/src/mcp/types.py index 829eaa5a0..4f5af27b9 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -218,7 +218,7 @@ class RootsCapability(BaseModel): class SamplingCapability(BaseModel): - """Capability for logging operations.""" + """Capability for sampling operations.""" model_config = ConfigDict(extra="allow") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index cad89f217..72b4413d2 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,8 +1,11 @@ +from typing import Any + import anyio import pytest import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -380,3 +383,167 @@ async def mock_server(): # Should raise RuntimeError for unsupported version with pytest.raises(RuntimeError, match="Unsupported protocol version"): await session.initialize() + + +@pytest.mark.anyio +async def test_client_capabilities_default(): + """Test that client capabilities are properly set with default callbacks""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that capabilities are properly set with defaults + assert received_capabilities is not None + assert received_capabilities.sampling is None # No custom sampling callback + assert received_capabilities.roots is None # No custom list_roots callback + + +@pytest.mark.anyio +async def test_client_capabilities_with_custom_callbacks(): + """Test that client capabilities are properly set with custom callbacks""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + received_capabilities = None + + async def custom_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text="test"), + model="test-model", + ) + + async def custom_list_roots_callback( + context: RequestContext["ClientSession", Any], + ) -> types.ListRootsResult | types.ErrorData: + return types.ListRootsResult(roots=[]) + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=custom_sampling_callback, + list_roots_callback=custom_list_roots_callback, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that capabilities are properly set with custom callbacks + assert received_capabilities is not None + assert ( + received_capabilities.sampling is not None + ) # Custom sampling callback provided + assert isinstance(received_capabilities.sampling, types.SamplingCapability) + assert ( + received_capabilities.roots is not None + ) # Custom list_roots callback provided + assert isinstance(received_capabilities.roots, types.RootsCapability) + assert ( + received_capabilities.roots.listChanged is True + ) # Should be True for custom callback