Skip to content

Commit 758682e

Browse files
dsp-antClaude
and
Claude
committed
Handle message callbacks in ClientSession
This change adds a message_handler callback to ClientSession to allow for direct handling of incoming messages instead of requiring an async iterator. The change simplifies the client code by removing the need for a separate receive loop task. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 85514b4 commit 758682e

File tree

10 files changed

+190
-86
lines changed

10 files changed

+190
-86
lines changed

src/mcp/client/__main__.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import anyio
88
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
99

10+
import mcp.types as types
1011
from mcp.client.session import ClientSession
1112
from mcp.client.sse import sse_client
1213
from mcp.client.stdio import StdioServerParameters, stdio_client
14+
from mcp.shared.session import RequestResponder
1315
from mcp.types import JSONRPCMessage
1416

1517
if not sys.warnoptions:
@@ -21,26 +23,25 @@
2123
logger = logging.getLogger("client")
2224

2325

24-
async def receive_loop(session: ClientSession):
25-
logger.info("Starting receive loop")
26-
async for message in session.incoming_messages:
27-
if isinstance(message, Exception):
28-
logger.error("Error: %s", message)
29-
continue
26+
async def message_handler(
27+
message: RequestResponder[types.ServerRequest, types.ClientResult]
28+
| types.ServerNotification
29+
| Exception,
30+
) -> None:
31+
if isinstance(message, Exception):
32+
logger.error("Error: %s", message)
33+
return
3034

31-
logger.info("Received message from server: %s", message)
35+
logger.info("Received message from server: %s", message)
3236

3337

3438
async def run_session(
3539
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
3640
write_stream: MemoryObjectSendStream[JSONRPCMessage],
3741
):
38-
async with (
39-
ClientSession(read_stream, write_stream) as session,
40-
anyio.create_task_group() as tg,
41-
):
42-
tg.start_soon(receive_loop, session)
43-
42+
async with ClientSession(
43+
read_stream, write_stream, message_handler=message_handler
44+
) as session:
4445
logger.info("Initializing session")
4546
await session.initialize()
4647
logger.info("Initialized")

src/mcp/client/session.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4+
import anyio.lowlevel
45
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
56
from pydantic import AnyUrl, TypeAdapter
67

@@ -31,6 +32,23 @@ async def __call__(
3132
) -> None: ...
3233

3334

35+
class MessageHandlerFnT(Protocol):
36+
async def __call__(
37+
self,
38+
message: RequestResponder[types.ServerRequest, types.ClientResult]
39+
| types.ServerNotification
40+
| Exception,
41+
) -> None: ...
42+
43+
44+
async def _default_message_handler(
45+
message: RequestResponder[types.ServerRequest, types.ClientResult]
46+
| types.ServerNotification
47+
| Exception,
48+
) -> None:
49+
await anyio.lowlevel.checkpoint()
50+
51+
3452
async def _default_sampling_callback(
3553
context: RequestContext["ClientSession", Any],
3654
params: types.CreateMessageRequestParams,
@@ -78,6 +96,7 @@ def __init__(
7896
sampling_callback: SamplingFnT | None = None,
7997
list_roots_callback: ListRootsFnT | None = None,
8098
logging_callback: LoggingFnT | None = None,
99+
message_handler: MessageHandlerFnT | None = None,
81100
) -> None:
82101
super().__init__(
83102
read_stream,
@@ -89,6 +108,7 @@ def __init__(
89108
self._sampling_callback = sampling_callback or _default_sampling_callback
90109
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
91110
self._logging_callback = logging_callback or _default_logging_callback
111+
self._message_handler = message_handler or _default_message_handler
92112

93113
async def initialize(self) -> types.InitializeResult:
94114
sampling = types.SamplingCapability()
@@ -337,10 +357,20 @@ async def _received_request(
337357
types.ClientResult(root=types.EmptyResult())
338358
)
339359

360+
async def _handle_incoming(
361+
self,
362+
req: RequestResponder[types.ServerRequest, types.ClientResult]
363+
| types.ServerNotification
364+
| Exception,
365+
) -> None:
366+
"""Handle incoming messages by forwarding to the message handler."""
367+
await self._message_handler(req)
368+
340369
async def _received_notification(
341370
self, notification: types.ServerNotification
342371
) -> None:
343372
"""Handle notifications from the server."""
373+
# Process specific notification types
344374
match notification.root:
345375
case types.LoggingMessageNotification(params=params):
346376
await self._logging_callback(params)

src/mcp/server/session.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
self._initialization_state = InitializationState.NotInitialized
9393
self._init_options = init_options
9494
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
95-
anyio.create_memory_object_stream[ServerRequestResponder]()
95+
anyio.create_memory_object_stream[ServerRequestResponder](0)
9696
)
9797
self._exit_stack.push_async_callback(
9898
lambda: self._incoming_message_stream_reader.aclose()
@@ -308,7 +308,7 @@ async def send_prompt_list_changed(self) -> None:
308308
)
309309

310310
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
311-
return await self._incoming_message_stream_writer.send(req)
311+
await self._incoming_message_stream_writer.send(req)
312312

313313
@property
314314
def incoming_messages(

src/mcp/shared/memory.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
12+
from mcp.client.session import (
13+
ClientSession,
14+
ListRootsFnT,
15+
LoggingFnT,
16+
MessageHandlerFnT,
17+
SamplingFnT,
18+
)
1319
from mcp.server import Server
1420
from mcp.types import JSONRPCMessage
1521

@@ -57,6 +63,7 @@ async def create_connected_server_and_client_session(
5763
sampling_callback: SamplingFnT | None = None,
5864
list_roots_callback: ListRootsFnT | None = None,
5965
logging_callback: LoggingFnT | None = None,
66+
message_handler: MessageHandlerFnT | None = None,
6067
raise_exceptions: bool = False,
6168
) -> AsyncGenerator[ClientSession, None]:
6269
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -86,6 +93,7 @@ async def create_connected_server_and_client_session(
8693
sampling_callback=sampling_callback,
8794
list_roots_callback=list_roots_callback,
8895
logging_callback=logging_callback,
96+
message_handler=message_handler,
8997
) as client_session:
9098
await client_session.initialize()
9199
yield client_session

src/mcp/shared/session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,4 @@ async def _handle_incoming(
378378
| Exception,
379379
) -> None:
380380
"""A generic handler for incoming messages. Overwritten by subclasses."""
381-
await anyio.lowlevel.checkpoint()
381+
pass

test.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import logging
2+
3+
from mcp import ClientSession, StdioServerParameters
4+
from mcp.client.sse import sse_client
5+
from mcp.client.stdio import stdio_client
6+
7+
# disable httpx logs
8+
logging.basicConfig(level=logging.DEBUG)
9+
logging.getLogger("httpcore.http11").setLevel(logging.CRITICAL)
10+
logging.getLogger("httpx").setLevel(logging.CRITICAL)
11+
logging.getLogger("urllib3.connectionpool").setLevel(logging.CRITICAL)
12+
13+
code = """
14+
import numpy
15+
a = numpy.array([1, 2, 3])
16+
print(a)
17+
a
18+
"""
19+
20+
21+
async def call_tools(session: ClientSession):
22+
# await session.initialize()
23+
await session.set_logging_level("debug")
24+
tools = await session.list_tools()
25+
print(f"{tools=}")
26+
result = await session.call_tool("run_python_code", {"python_code": code})
27+
print(f"{result=}")
28+
29+
30+
async def sse():
31+
async with sse_client("http://localhost:3001/sse") as (read, write):
32+
async with ClientSession(read, write) as session:
33+
await call_tools(session)
34+
35+
36+
async def stdio():
37+
server_params = StdioServerParameters(
38+
command="npx",
39+
args=[
40+
"--registry=https://registry.npmjs.org",
41+
"@pydantic/mcp-run-python",
42+
"stdio",
43+
],
44+
)
45+
async with stdio_client(server_params) as (read, write):
46+
async with ClientSession(read, write) as session:
47+
await call_tools(session)
48+
49+
50+
if __name__ == "__main__":
51+
import asyncio
52+
53+
asyncio.run(sse())

tests/client/test_logging_callback.py

+34-36
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import List, Literal
22

3-
import anyio
43
import pytest
54

5+
import mcp.types as types
66
from mcp.shared.memory import (
77
create_connected_server_and_client_session as create_session,
88
)
9+
from mcp.shared.session import RequestResponder
910
from mcp.types import (
1011
LoggingMessageNotificationParams,
1112
TextContent,
@@ -46,40 +47,37 @@ async def test_tool_with_log(
4647
)
4748
return True
4849

49-
async with anyio.create_task_group() as tg:
50-
async with create_session(
51-
server._mcp_server, logging_callback=logging_collector
52-
) as client_session:
50+
# Create a message handler to catch exceptions
51+
async def message_handler(
52+
message: RequestResponder[types.ServerRequest, types.ClientResult]
53+
| types.ServerNotification
54+
| Exception,
55+
) -> None:
56+
if isinstance(message, Exception):
57+
raise message
5358

54-
async def listen_session():
55-
try:
56-
async for message in client_session.incoming_messages:
57-
if isinstance(message, Exception):
58-
raise message
59-
except anyio.EndOfStream:
60-
pass
59+
async with create_session(
60+
server._mcp_server,
61+
logging_callback=logging_collector,
62+
message_handler=message_handler,
63+
) as client_session:
64+
# First verify our test tool works
65+
result = await client_session.call_tool("test_tool", {})
66+
assert result.isError is False
67+
assert isinstance(result.content[0], TextContent)
68+
assert result.content[0].text == "true"
6169

62-
tg.start_soon(listen_session)
63-
64-
# First verify our test tool works
65-
result = await client_session.call_tool("test_tool", {})
66-
assert result.isError is False
67-
assert isinstance(result.content[0], TextContent)
68-
assert result.content[0].text == "true"
69-
70-
# Now send a log message via our tool
71-
log_result = await client_session.call_tool(
72-
"test_tool_with_log",
73-
{
74-
"message": "Test log message",
75-
"level": "info",
76-
"logger": "test_logger",
77-
},
78-
)
79-
assert log_result.isError is False
80-
assert len(logging_collector.log_messages) == 1
81-
assert logging_collector.log_messages[
82-
0
83-
] == LoggingMessageNotificationParams(
84-
level="info", logger="test_logger", data="Test log message"
85-
)
70+
# Now send a log message via our tool
71+
log_result = await client_session.call_tool(
72+
"test_tool_with_log",
73+
{
74+
"message": "Test log message",
75+
"level": "info",
76+
"logger": "test_logger",
77+
},
78+
)
79+
assert log_result.isError is False
80+
assert len(logging_collector.log_messages) == 1
81+
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
82+
level="info", logger="test_logger", data="Test log message"
83+
)

tests/client/test_session.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import anyio
22
import pytest
33

4+
import mcp.types as types
45
from mcp.client.session import ClientSession
6+
from mcp.shared.session import RequestResponder
57
from mcp.types import (
68
LATEST_PROTOCOL_VERSION,
79
ClientNotification,
@@ -75,21 +77,28 @@ async def mock_server():
7577
)
7678
)
7779

78-
async def listen_session():
79-
async for message in session.incoming_messages:
80-
if isinstance(message, Exception):
81-
raise message
80+
# Create a message handler to catch exceptions
81+
async def message_handler(
82+
message: RequestResponder[types.ServerRequest, types.ClientResult]
83+
| types.ServerNotification
84+
| Exception,
85+
) -> None:
86+
if isinstance(message, Exception):
87+
raise message
8288

8389
async with (
84-
ClientSession(server_to_client_receive, client_to_server_send) as session,
90+
ClientSession(
91+
server_to_client_receive,
92+
client_to_server_send,
93+
message_handler=message_handler,
94+
) as session,
8595
anyio.create_task_group() as tg,
8696
client_to_server_send,
8797
client_to_server_receive,
8898
server_to_client_send,
8999
server_to_client_receive,
90100
):
91101
tg.start_soon(mock_server)
92-
tg.start_soon(listen_session)
93102
result = await session.initialize()
94103

95104
# Assert the result

0 commit comments

Comments
 (0)