Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback for logging message notification to client #314

Merged
merged 1 commit into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ async def __call__(
) -> types.ListRootsResult | types.ErrorData: ...


class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ...


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
Expand All @@ -43,6 +50,12 @@ async def _default_list_roots_callback(
)


async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass


ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)
Expand All @@ -64,6 +77,7 @@ def __init__(
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -74,6 +88,7 @@ def __init__(
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
Expand Down Expand Up @@ -321,3 +336,13 @@ async def _received_request(
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)

async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case _:
pass
4 changes: 3 additions & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage

Expand Down Expand Up @@ -56,6 +56,7 @@ async def create_connected_server_and_client_session(
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
Expand Down Expand Up @@ -84,6 +85,7 @@ async def create_connected_server_and_client_session(
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
85 changes: 85 additions & 0 deletions tests/client/test_logging_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import List, Literal

import anyio
import pytest

from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
LoggingMessageNotificationParams,
TextContent,
)


class LoggingCollector:
def __init__(self):
self.log_messages: List[LoggingMessageNotificationParams] = []

async def __call__(self, params: LoggingMessageNotificationParams) -> None:
self.log_messages.append(params)


@pytest.mark.anyio
async def test_logging_callback():
from mcp.server.fastmcp import FastMCP

server = FastMCP("test")
logging_collector = LoggingCollector()

# Create a simple test tool
@server.tool("test_tool")
async def test_tool() -> bool:
# The actual tool is very simple and just returns True
return True

# Create a function that can send a log notification
@server.tool("test_tool_with_log")
async def test_tool_with_log(
message: str, level: Literal["debug", "info", "warning", "error"], logger: str
) -> bool:
"""Send a log notification to the client."""
await server.get_context().log(
level=level,
message=message,
logger_name=logger,
)
return True

async with anyio.create_task_group() as tg:
async with create_session(
server._mcp_server, logging_callback=logging_collector
) as client_session:

async def listen_session():
try:
async for message in client_session.incoming_messages:
if isinstance(message, Exception):
raise message
except anyio.EndOfStream:
pass

tg.start_soon(listen_session)

# First verify our test tool works
result = await client_session.call_tool("test_tool", {})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Now send a log message via our tool
log_result = await client_session.call_tool(
"test_tool_with_log",
{
"message": "Test log message",
"level": "info",
"logger": "test_logger",
},
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[
0
] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)