From fc9e388f8be28e7a579eb58e0ed7686fdd63e017 Mon Sep 17 00:00:00 2001 From: Jerome Date: Thu, 10 Apr 2025 13:33:31 +0100 Subject: [PATCH 1/3] Added clientInfo to client session init args --- src/mcp/client/session.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 65d5e11e..2da75d8d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -11,6 +11,8 @@ from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +_DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") + class SamplingFnT(Protocol): async def __call__( self, @@ -97,6 +99,7 @@ def __init__( list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, ) -> None: super().__init__( read_stream, @@ -105,6 +108,7 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) + self._client_info = client_info or _DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback @@ -130,7 +134,7 @@ async def initialize(self) -> types.InitializeResult: experimental=None, roots=roots, ), - clientInfo=types.Implementation(name="mcp", version="0.1.0"), + clientInfo=self._client_info, ), ) ), From 36b2b19bc3c882d16bba2fd2324bf42a4b32d5df Mon Sep 17 00:00:00 2001 From: Jerome Date: Thu, 10 Apr 2025 13:45:48 +0100 Subject: [PATCH 2/3] Add tests and propagate client_info parameter throughout client APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/client/__main__.py | 6 +- src/mcp/client/session.py | 2 +- src/mcp/shared/memory.py | 3 + tests/client/test_session.py | 130 ++++++++++++++++++++++++++++++++++- 4 files changed, 138 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 39b4f45c..84e15bd5 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -38,9 +38,13 @@ async def message_handler( async def run_session( read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], + client_info: types.Implementation | None = None, ): async with ClientSession( - read_stream, write_stream, message_handler=message_handler + read_stream, + write_stream, + message_handler=message_handler, + client_info=client_info, ) as session: logger.info("Initializing session") await session.initialize() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 2da75d8d..fda9aee3 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -10,9 +10,9 @@ from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS - _DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") + class SamplingFnT(Protocol): async def __call__( self, diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 346f6156..abf87a3a 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -10,6 +10,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +import mcp.types as types from mcp.client.session import ( ClientSession, ListRootsFnT, @@ -65,6 +66,7 @@ async def create_connected_server_and_client_session( list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -95,6 +97,7 @@ async def create_connected_server_and_client_session( list_roots_callback=list_roots_callback, logging_callback=logging_callback, message_handler=message_handler, + client_info=client_info, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f250a05b..325bc4eb 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,7 +2,7 @@ import pytest import mcp.types as types -from mcp.client.session import ClientSession +from mcp.client.session import _DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -111,3 +111,131 @@ async def message_handler( # Check that the client sent the initialized notification assert initialized_notification assert isinstance(initialized_notification.root, InitializedNotification) + + +@pytest.mark.anyio +async def test_client_session_custom_client_info(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + + custom_client_info = Implementation(name="test-client", version="1.2.3") + received_client_info = None + + async def mock_server(): + nonlocal received_client_info + + jsonrpc_request = await client_to_server_receive.receive() + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_client_info = request.root.params.clientInfo + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + client_info=custom_client_info, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that the custom client info was sent + assert received_client_info == custom_client_info + + +@pytest.mark.anyio +async def test_client_session_default_client_info(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](1) + + received_client_info = None + + async def mock_server(): + nonlocal received_client_info + + jsonrpc_request = await client_to_server_receive.receive() + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_client_info = request.root.params.clientInfo + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that the default client info was sent + assert received_client_info == _DEFAULT_CLIENT_INFO From cd236f6e62fb5ae49f0e7647bab2d24d44458407 Mon Sep 17 00:00:00 2001 From: Jerome Date: Thu, 10 Apr 2025 13:47:36 +0100 Subject: [PATCH 3/3] Make DEFAULT_CLIENT_INFO public MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/mcp/client/session.py | 4 ++-- tests/client/test_session.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fda9aee3..e29797d1 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -10,7 +10,7 @@ from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -_DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") +DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") class SamplingFnT(Protocol): @@ -108,7 +108,7 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - self._client_info = client_info or _DEFAULT_CLIENT_INFO + self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 325bc4eb..543ebb2f 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,7 +2,7 @@ import pytest import mcp.types as types -from mcp.client.session import _DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -238,4 +238,4 @@ async def mock_server(): await session.initialize() # Assert that the default client info was sent - assert received_client_info == _DEFAULT_CLIENT_INFO + assert received_client_info == DEFAULT_CLIENT_INFO