Skip to content

Commit 08f4e01

Browse files
authored
add callback for logging message notification (#314)
1 parent a9aca20 commit 08f4e01

File tree

3 files changed

+113
-1
lines changed

3 files changed

+113
-1
lines changed

src/mcp/client/session.py

+25
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ async def __call__(
2424
) -> types.ListRootsResult | types.ErrorData: ...
2525

2626

27+
class LoggingFnT(Protocol):
28+
async def __call__(
29+
self,
30+
params: types.LoggingMessageNotificationParams,
31+
) -> None: ...
32+
33+
2734
async def _default_sampling_callback(
2835
context: RequestContext["ClientSession", Any],
2936
params: types.CreateMessageRequestParams,
@@ -43,6 +50,12 @@ async def _default_list_roots_callback(
4350
)
4451

4552

53+
async def _default_logging_callback(
54+
params: types.LoggingMessageNotificationParams,
55+
) -> None:
56+
pass
57+
58+
4659
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
4760
types.ClientResult | types.ErrorData
4861
)
@@ -64,6 +77,7 @@ def __init__(
6477
read_timeout_seconds: timedelta | None = None,
6578
sampling_callback: SamplingFnT | None = None,
6679
list_roots_callback: ListRootsFnT | None = None,
80+
logging_callback: LoggingFnT | None = None,
6781
) -> None:
6882
super().__init__(
6983
read_stream,
@@ -74,6 +88,7 @@ def __init__(
7488
)
7589
self._sampling_callback = sampling_callback or _default_sampling_callback
7690
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
91+
self._logging_callback = logging_callback or _default_logging_callback
7792

7893
async def initialize(self) -> types.InitializeResult:
7994
sampling = types.SamplingCapability()
@@ -321,3 +336,13 @@ async def _received_request(
321336
return await responder.respond(
322337
types.ClientResult(root=types.EmptyResult())
323338
)
339+
340+
async def _received_notification(
341+
self, notification: types.ServerNotification
342+
) -> None:
343+
"""Handle notifications from the server."""
344+
match notification.root:
345+
case types.LoggingMessageNotification(params=params):
346+
await self._logging_callback(params)
347+
case _:
348+
pass

src/mcp/shared/memory.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
12+
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -56,6 +56,7 @@ async def create_connected_server_and_client_session(
5656
read_timeout_seconds: timedelta | None = None,
5757
sampling_callback: SamplingFnT | None = None,
5858
list_roots_callback: ListRootsFnT | None = None,
59+
logging_callback: LoggingFnT | None = None,
5960
raise_exceptions: bool = False,
6061
) -> AsyncGenerator[ClientSession, None]:
6162
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -84,6 +85,7 @@ async def create_connected_server_and_client_session(
8485
read_timeout_seconds=read_timeout_seconds,
8586
sampling_callback=sampling_callback,
8687
list_roots_callback=list_roots_callback,
88+
logging_callback=logging_callback,
8789
) as client_session:
8890
await client_session.initialize()
8991
yield client_session

tests/client/test_logging_callback.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import List, Literal
2+
3+
import anyio
4+
import pytest
5+
6+
from mcp.shared.memory import (
7+
create_connected_server_and_client_session as create_session,
8+
)
9+
from mcp.types import (
10+
LoggingMessageNotificationParams,
11+
TextContent,
12+
)
13+
14+
15+
class LoggingCollector:
16+
def __init__(self):
17+
self.log_messages: List[LoggingMessageNotificationParams] = []
18+
19+
async def __call__(self, params: LoggingMessageNotificationParams) -> None:
20+
self.log_messages.append(params)
21+
22+
23+
@pytest.mark.anyio
24+
async def test_logging_callback():
25+
from mcp.server.fastmcp import FastMCP
26+
27+
server = FastMCP("test")
28+
logging_collector = LoggingCollector()
29+
30+
# Create a simple test tool
31+
@server.tool("test_tool")
32+
async def test_tool() -> bool:
33+
# The actual tool is very simple and just returns True
34+
return True
35+
36+
# Create a function that can send a log notification
37+
@server.tool("test_tool_with_log")
38+
async def test_tool_with_log(
39+
message: str, level: Literal["debug", "info", "warning", "error"], logger: str
40+
) -> bool:
41+
"""Send a log notification to the client."""
42+
await server.get_context().log(
43+
level=level,
44+
message=message,
45+
logger_name=logger,
46+
)
47+
return True
48+
49+
async with anyio.create_task_group() as tg:
50+
async with create_session(
51+
server._mcp_server, logging_callback=logging_collector
52+
) as client_session:
53+
54+
async def listen_session():
55+
try:
56+
async for message in client_session.incoming_messages:
57+
if isinstance(message, Exception):
58+
raise message
59+
except anyio.EndOfStream:
60+
pass
61+
62+
tg.start_soon(listen_session)
63+
64+
# First verify our test tool works
65+
result = await client_session.call_tool("test_tool", {})
66+
assert result.isError is False
67+
assert isinstance(result.content[0], TextContent)
68+
assert result.content[0].text == "true"
69+
70+
# Now send a log message via our tool
71+
log_result = await client_session.call_tool(
72+
"test_tool_with_log",
73+
{
74+
"message": "Test log message",
75+
"level": "info",
76+
"logger": "test_logger",
77+
},
78+
)
79+
assert log_result.isError is False
80+
assert len(logging_collector.log_messages) == 1
81+
assert logging_collector.log_messages[
82+
0
83+
] == LoggingMessageNotificationParams(
84+
level="info", logger="test_logger", data="Test log message"
85+
)

0 commit comments

Comments
 (0)