From 6ecf19db68cd642a35f876cadd4f0bfb4df07b3a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 14 Mar 2025 11:53:21 +0100 Subject: [PATCH] Strict types on the client side --- pyproject.toml | 1 + src/mcp/client/__main__.py | 7 ++++++- src/mcp/client/session.py | 18 ++++++------------ src/mcp/client/sse.py | 4 ++++ src/mcp/client/websocket.py | 5 +++++ src/mcp/shared/version.py | 2 +- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f352de5a..69db82c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ venvPath = "." venv = ".venv" strict = [ "src/mcp/server/fastmcp/tools/base.py", + "src/mcp/client/*.py" ] [tool.ruff.lint] diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 8ce704ff..baf815c0 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -5,10 +5,12 @@ from urllib.parse import urlparse import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import JSONRPCMessage if not sys.warnoptions: import warnings @@ -29,7 +31,10 @@ async def receive_loop(session: ClientSession): logger.info("Received message from server: %s", message) -async def run_session(read_stream, write_stream): +async def run_session( + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], +): async with ( ClientSession(read_stream, write_stream) as session, anyio.create_task_group() as tg, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cde3103b..2ac24877 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -76,18 +76,12 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback async def initialize(self) -> types.InitializeResult: - sampling = ( - types.SamplingCapability() if self._sampling_callback is not None else None - ) - roots = ( - types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True, - ) - if self._list_roots_callback is not None - else None + sampling = types.SamplingCapability() + roots = types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, ) result = await self.send_request( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb9..4f6241a7 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -98,6 +98,10 @@ async def sse_reader( continue await read_stream_writer.send(message) + case _: + logger.warning( + f"Unknown SSE event: {sse.event}" + ) except Exception as exc: logger.error(f"Error in sse_reader: {exc}") await read_stream_writer.send(exc) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 3e73b020..9cf32296 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -39,6 +39,11 @@ async def websocket_client( # Create two in-memory streams: # - One for incoming messages (read_stream, written by ws_reader) # - One for outgoing messages (write_stream, read by ws_writer) + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 51bf3521..8fd13b99 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -1,3 +1,3 @@ from mcp.types import LATEST_PROTOCOL_VERSION -SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION] +SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)