From feaf2423ba8b114cd22cc68f5fe45e75229ad0d5 Mon Sep 17 00:00:00 2001 From: Inna Date: Tue, 18 Mar 2025 22:00:14 +0000 Subject: [PATCH] add callback for logging message notification --- src/mcp/client/session.py | 25 ++++++++ src/mcp/shared/memory.py | 4 +- tests/client/test_logging_callback.py | 85 +++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 tests/client/test_logging_callback.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 2ac24877..8acf3295 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -24,6 +24,13 @@ async def __call__( ) -> types.ListRootsResult | types.ErrorData: ... +class LoggingFnT(Protocol): + async def __call__( + self, + params: types.LoggingMessageNotificationParams, + ) -> None: ... + + async def _default_sampling_callback( context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams, @@ -43,6 +50,12 @@ async def _default_list_roots_callback( ) +async def _default_logging_callback( + params: types.LoggingMessageNotificationParams, +) -> None: + pass + + ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( types.ClientResult | types.ErrorData ) @@ -64,6 +77,7 @@ def __init__( read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, ) -> None: super().__init__( read_stream, @@ -74,6 +88,7 @@ def __init__( ) self._sampling_callback = sampling_callback or _default_sampling_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback + self._logging_callback = logging_callback or _default_logging_callback async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -321,3 +336,13 @@ async def _received_request( return await responder.respond( types.ClientResult(root=types.EmptyResult()) ) + + async def _received_notification( + self, notification: types.ServerNotification + ) -> None: + """Handle notifications from the server.""" + match notification.root: + case types.LoggingMessageNotification(params=params): + await self._logging_callback(params) + case _: + pass diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index ae6b0be5..495f0c1e 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT +from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -56,6 +56,7 @@ async def create_connected_server_and_client_session( read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -84,6 +85,7 @@ async def create_connected_server_and_client_session( read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, + logging_callback=logging_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py new file mode 100644 index 00000000..ead4f092 --- /dev/null +++ b/tests/client/test_logging_callback.py @@ -0,0 +1,85 @@ +from typing import List, Literal + +import anyio +import pytest + +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + LoggingMessageNotificationParams, + TextContent, +) + + +class LoggingCollector: + def __init__(self): + self.log_messages: List[LoggingMessageNotificationParams] = [] + + async def __call__(self, params: LoggingMessageNotificationParams) -> None: + self.log_messages.append(params) + + +@pytest.mark.anyio +async def test_logging_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + logging_collector = LoggingCollector() + + # Create a simple test tool + @server.tool("test_tool") + async def test_tool() -> bool: + # The actual tool is very simple and just returns True + return True + + # Create a function that can send a log notification + @server.tool("test_tool_with_log") + async def test_tool_with_log( + message: str, level: Literal["debug", "info", "warning", "error"], logger: str + ) -> bool: + """Send a log notification to the client.""" + await server.get_context().log( + level=level, + message=message, + logger_name=logger, + ) + return True + + async with anyio.create_task_group() as tg: + async with create_session( + server._mcp_server, logging_callback=logging_collector + ) as client_session: + + async def listen_session(): + try: + async for message in client_session.incoming_messages: + if isinstance(message, Exception): + raise message + except anyio.EndOfStream: + pass + + tg.start_soon(listen_session) + + # First verify our test tool works + result = await client_session.call_tool("test_tool", {}) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Now send a log message via our tool + log_result = await client_session.call_tool( + "test_tool_with_log", + { + "message": "Test log message", + "level": "info", + "logger": "test_logger", + }, + ) + assert log_result.isError is False + assert len(logging_collector.log_messages) == 1 + assert logging_collector.log_messages[ + 0 + ] == LoggingMessageNotificationParams( + level="info", logger="test_logger", data="Test log message" + )