diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index baf815c0..39b4f45c 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -7,9 +7,11 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.session import RequestResponder from mcp.types import JSONRPCMessage if not sys.warnoptions: @@ -21,26 +23,25 @@ logger = logging.getLogger("client") -async def receive_loop(session: ClientSession): - logger.info("Starting receive loop") - async for message in session.incoming_messages: - if isinstance(message, Exception): - logger.error("Error: %s", message) - continue +async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, +) -> None: + if isinstance(message, Exception): + logger.error("Error: %s", message) + return - logger.info("Received message from server: %s", message) + logger.info("Received message from server: %s", message) async def run_session( read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], ): - async with ( - ClientSession(read_stream, write_stream) as session, - anyio.create_task_group() as tg, - ): - tg.start_soon(receive_loop, session) - + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: logger.info("Initializing session") await session.initialize() logger.info("Initialized") diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8acf3295..65d5e11e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,6 +1,7 @@ from datetime import timedelta from typing import Any, Protocol +import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl, TypeAdapter @@ -31,6 +32,23 @@ async def __call__( ) -> None: ... +class MessageHandlerFnT(Protocol): + async def __call__( + self, + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: ... + + +async def _default_message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, +) -> None: + await anyio.lowlevel.checkpoint() + + async def _default_sampling_callback( context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, @@ -78,6 +96,7 @@ def __init__( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, ) -> None: super().__init__( read_stream, @@ -89,6 +108,7 @@ def __init__( self._sampling_callback = sampling_callback or _default_sampling_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback + self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -337,10 +357,20 @@ async def _received_request( types.ClientResult(root=types.EmptyResult()) ) + async def _handle_incoming( + self, + req: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + """Handle incoming messages by forwarding to the message handler.""" + await self._message_handler(req) + async def _received_notification( self, notification: types.ServerNotification ) -> None: """Handle notifications from the server.""" + # Process specific notification types match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 788bb9f8..568ecd4b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -61,6 +61,12 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerRequestResponder = ( + RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception +) + class ServerSession( BaseSession[ @@ -85,6 +91,15 @@ def __init__( ) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options + self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( + anyio.create_memory_object_stream[ServerRequestResponder](0) + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_reader.aclose() + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_writer.aclose() + ) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -291,3 +306,12 @@ async def send_prompt_list_changed(self) -> None: ) ) ) + + async def _handle_incoming(self, req: ServerRequestResponder) -> None: + await self._incoming_message_stream_writer.send(req) + + @property + def incoming_messages( + self, + ) -> MemoryObjectReceiveStream[ServerRequestResponder]: + return self._incoming_message_stream_reader diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 938d4a30..346f6156 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -10,7 +10,13 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) from mcp.server import Server from mcp.types import JSONRPCMessage @@ -58,6 +64,7 @@ async def create_connected_server_and_client_session( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -87,6 +94,7 @@ async def create_connected_server_and_client_session( sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, logging_callback=logging_callback, + message_handler=message_handler, ) as client_session: await client_session.initialize() yield client_session diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 31c04df3..05fd3ce3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -189,19 +189,6 @@ def __init__( self._in_flight = {} self._exit_stack = AsyncExitStack() - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ - RequestResponder[ReceiveRequestT, SendResultT] - | ReceiveNotificationT - | Exception - ]() - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_reader.aclose() - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() - ) async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -312,11 +299,10 @@ async def _receive_loop(self) -> None: async with ( self._read_stream, self._write_stream, - self._incoming_message_stream_writer, ): async for message in self._read_stream: if isinstance(message, Exception): - await self._incoming_message_stream_writer.send(message) + await self._handle_incoming(message) elif isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( @@ -336,8 +322,9 @@ async def _receive_loop(self) -> None: self._in_flight[responder.request_id] = responder await self._received_request(responder) + if not responder._completed: # type: ignore[reportPrivateUsage] - await self._incoming_message_stream_writer.send(responder) + await self._handle_incoming(responder) elif isinstance(message.root, JSONRPCNotification): try: @@ -353,9 +340,7 @@ async def _receive_loop(self) -> None: await self._in_flight[cancelled_id].cancel() else: await self._received_notification(notification) - await self._incoming_message_stream_writer.send( - notification - ) + await self._handle_incoming(notification) except Exception as e: # For other validation errors, log and continue logging.warning( @@ -367,7 +352,7 @@ async def _receive_loop(self) -> None: if stream: await stream.send(message.root) else: - await self._incoming_message_stream_writer.send( + await self._handle_incoming( RuntimeError( "Received response with an unknown " f"request ID: {message}" @@ -399,12 +384,11 @@ async def send_progress_notification( processed. """ - @property - def incoming_messages( + async def _handle_incoming( self, - ) -> MemoryObjectReceiveStream[ - RequestResponder[ReceiveRequestT, SendResultT] + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT - | Exception - ]: - return self._incoming_message_stream_reader + | Exception, + ) -> None: + """A generic handler for incoming messages. Overwritten by subclasses.""" + pass diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 74f4b487..797f817e 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,11 +1,12 @@ from typing import Literal -import anyio import pytest +import mcp.types as types from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, ) +from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, TextContent, @@ -46,40 +47,37 @@ async def test_tool_with_log( ) return True - async with anyio.create_task_group() as tg: - async with create_session( - server._mcp_server, logging_callback=logging_collector - ) as client_session: + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message - async def listen_session(): - try: - async for message in client_session.incoming_messages: - if isinstance(message, Exception): - raise message - except anyio.EndOfStream: - pass + async with create_session( + server._mcp_server, + logging_callback=logging_collector, + message_handler=message_handler, + ) as client_session: + # First verify our test tool works + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" - tg.start_soon(listen_session) - - # First verify our test tool works - result = await client_session.call_tool("test_tool", {}) - assert result.isError is False - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "true" - - # Now send a log message via our tool - log_result = await client_session.call_tool( - "test_tool_with_log", - { - "message": "Test log message", - "level": "info", - "logger": "test_logger", - }, - ) - assert log_result.isError is False - assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[ - 0 - ] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Now send a log message via our tool + log_result = await client_session.call_tool( + "test_tool_with_log", + { + "message": "Test log message", + "level": "info", + "logger": "test_logger", + }, + ) + assert log_result.isError is False + assert len(logging_collector.log_messages) == 1 + assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( + level="info", logger="test_logger", data="Test log message" + ) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 7d579cda..f250a05b 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,7 +1,9 @@ import anyio import pytest +import mcp.types as types from mcp.client.session import ClientSession +from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -75,13 +77,21 @@ async def mock_server(): ) ) - async def listen_session(): - async for message in session.incoming_messages: - if isinstance(message, Exception): - raise message + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message async with ( - ClientSession(server_to_client_receive, client_to_server_send) as session, + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -89,7 +99,6 @@ async def listen_session(): server_to_client_receive, ): tg.start_soon(mock_server) - tg.start_soon(listen_session) result = await session.initialize() # Assert the result diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 0aac6608..88e41d66 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -6,6 +6,7 @@ import anyio import pytest +from anyio.abc import TaskStatus from mcp.client.session import ClientSession from mcp.server.lowlevel import Server @@ -54,15 +55,21 @@ async def slow_tool( return [TextContent(type="text", text=f"fast {request_count}")] return [TextContent(type="text", text=f"unknown {request_count}")] - async def server_handler(read_stream, write_stream): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - raise_exceptions=True, - ) - - async def client(read_stream, write_stream): + async def server_handler( + read_stream, + write_stream, + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + with anyio.CancelScope() as scope: + task_status.started(scope) # type: ignore + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + raise_exceptions=True, + ) + + async def client(read_stream, write_stream, scope): # Use a timeout that's: # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) @@ -90,22 +97,13 @@ async def client(read_stream, write_stream): # proving server is still responsive result = await session.call_tool("fast") assert result.content == [TextContent(type="text", text="fast 3")] + scope.cancel() # Run server and client in separate task groups to avoid cancellation server_writer, server_reader = anyio.create_memory_object_stream(1) client_writer, client_reader = anyio.create_memory_object_stream(1) - server_ready = anyio.Event() - - async def wrapped_server_handler(read_stream, write_stream): - server_ready.set() - await server_handler(read_stream, write_stream) - async with anyio.create_task_group() as tg: - tg.start_soon(wrapped_server_handler, server_reader, client_writer) - # Wait for server to start and initialize - with anyio.fail_after(1): # Timeout after 1 second - await server_ready.wait() + scope = await tg.start(server_handler, server_reader, client_writer) # Run client in a separate task to avoid cancellation - async with anyio.create_task_group() as client_tg: - client_tg.start_soon(client, client_reader, server_writer) + tg.start_soon(client, client_reader, server_writer, scope) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 333196c9..561a94b6 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,11 +1,13 @@ import anyio import pytest +import mcp.types as types from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, InitializedNotification, @@ -25,10 +27,14 @@ async def test_server_session_initialize(): JSONRPCMessage ](1) - async def run_client(client: ClientSession): - async for message in client_session.incoming_messages: - if isinstance(message, Exception): - raise message + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message received_initialized = False @@ -57,11 +63,12 @@ async def run_server(): try: async with ( ClientSession( - server_to_client_receive, client_to_server_send + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, ) as client_session, anyio.create_task_group() as tg, ): - tg.start_soon(run_client, client_session) tg.start_soon(run_server) await client_session.initialize()