Skip to content

Commit 568cbd1

Browse files
dsp-antClaude
and
Claude
authored
Fix #201: Move incoming message stream from BaseSession to ServerSession (#325)
Co-authored-by: Claude <[email protected]>
1 parent 9ae4df8 commit 568cbd1

File tree

9 files changed

+169
-110
lines changed

9 files changed

+169
-110
lines changed

src/mcp/client/__main__.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import anyio
88
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
99

10+
import mcp.types as types
1011
from mcp.client.session import ClientSession
1112
from mcp.client.sse import sse_client
1213
from mcp.client.stdio import StdioServerParameters, stdio_client
14+
from mcp.shared.session import RequestResponder
1315
from mcp.types import JSONRPCMessage
1416

1517
if not sys.warnoptions:
@@ -21,26 +23,25 @@
2123
logger = logging.getLogger("client")
2224

2325

24-
async def receive_loop(session: ClientSession):
25-
logger.info("Starting receive loop")
26-
async for message in session.incoming_messages:
27-
if isinstance(message, Exception):
28-
logger.error("Error: %s", message)
29-
continue
26+
async def message_handler(
27+
message: RequestResponder[types.ServerRequest, types.ClientResult]
28+
| types.ServerNotification
29+
| Exception,
30+
) -> None:
31+
if isinstance(message, Exception):
32+
logger.error("Error: %s", message)
33+
return
3034

31-
logger.info("Received message from server: %s", message)
35+
logger.info("Received message from server: %s", message)
3236

3337

3438
async def run_session(
3539
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
3640
write_stream: MemoryObjectSendStream[JSONRPCMessage],
3741
):
38-
async with (
39-
ClientSession(read_stream, write_stream) as session,
40-
anyio.create_task_group() as tg,
41-
):
42-
tg.start_soon(receive_loop, session)
43-
42+
async with ClientSession(
43+
read_stream, write_stream, message_handler=message_handler
44+
) as session:
4445
logger.info("Initializing session")
4546
await session.initialize()
4647
logger.info("Initialized")

src/mcp/client/session.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4+
import anyio.lowlevel
45
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
56
from pydantic import AnyUrl, TypeAdapter
67

@@ -31,6 +32,23 @@ async def __call__(
3132
) -> None: ...
3233

3334

35+
class MessageHandlerFnT(Protocol):
36+
async def __call__(
37+
self,
38+
message: RequestResponder[types.ServerRequest, types.ClientResult]
39+
| types.ServerNotification
40+
| Exception,
41+
) -> None: ...
42+
43+
44+
async def _default_message_handler(
45+
message: RequestResponder[types.ServerRequest, types.ClientResult]
46+
| types.ServerNotification
47+
| Exception,
48+
) -> None:
49+
await anyio.lowlevel.checkpoint()
50+
51+
3452
async def _default_sampling_callback(
3553
context: RequestContext["ClientSession", Any],
3654
params: types.CreateMessageRequestParams,
@@ -78,6 +96,7 @@ def __init__(
7896
sampling_callback: SamplingFnT | None = None,
7997
list_roots_callback: ListRootsFnT | None = None,
8098
logging_callback: LoggingFnT | None = None,
99+
message_handler: MessageHandlerFnT | None = None,
81100
) -> None:
82101
super().__init__(
83102
read_stream,
@@ -89,6 +108,7 @@ def __init__(
89108
self._sampling_callback = sampling_callback or _default_sampling_callback
90109
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
91110
self._logging_callback = logging_callback or _default_logging_callback
111+
self._message_handler = message_handler or _default_message_handler
92112

93113
async def initialize(self) -> types.InitializeResult:
94114
sampling = types.SamplingCapability()
@@ -337,10 +357,20 @@ async def _received_request(
337357
types.ClientResult(root=types.EmptyResult())
338358
)
339359

360+
async def _handle_incoming(
361+
self,
362+
req: RequestResponder[types.ServerRequest, types.ClientResult]
363+
| types.ServerNotification
364+
| Exception,
365+
) -> None:
366+
"""Handle incoming messages by forwarding to the message handler."""
367+
await self._message_handler(req)
368+
340369
async def _received_notification(
341370
self, notification: types.ServerNotification
342371
) -> None:
343372
"""Handle notifications from the server."""
373+
# Process specific notification types
344374
match notification.root:
345375
case types.LoggingMessageNotification(params=params):
346376
await self._logging_callback(params)

src/mcp/server/session.py

+24
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ class InitializationState(Enum):
6161

6262
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
6363

64+
ServerRequestResponder = (
65+
RequestResponder[types.ClientRequest, types.ServerResult]
66+
| types.ClientNotification
67+
| Exception
68+
)
69+
6470

6571
class ServerSession(
6672
BaseSession[
@@ -85,6 +91,15 @@ def __init__(
8591
)
8692
self._initialization_state = InitializationState.NotInitialized
8793
self._init_options = init_options
94+
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
95+
anyio.create_memory_object_stream[ServerRequestResponder](0)
96+
)
97+
self._exit_stack.push_async_callback(
98+
lambda: self._incoming_message_stream_reader.aclose()
99+
)
100+
self._exit_stack.push_async_callback(
101+
lambda: self._incoming_message_stream_writer.aclose()
102+
)
88103

89104
@property
90105
def client_params(self) -> types.InitializeRequestParams | None:
@@ -291,3 +306,12 @@ async def send_prompt_list_changed(self) -> None:
291306
)
292307
)
293308
)
309+
310+
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
311+
await self._incoming_message_stream_writer.send(req)
312+
313+
@property
314+
def incoming_messages(
315+
self,
316+
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
317+
return self._incoming_message_stream_reader

src/mcp/shared/memory.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
import anyio
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212

13-
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
13+
from mcp.client.session import (
14+
ClientSession,
15+
ListRootsFnT,
16+
LoggingFnT,
17+
MessageHandlerFnT,
18+
SamplingFnT,
19+
)
1420
from mcp.server import Server
1521
from mcp.types import JSONRPCMessage
1622

@@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
5864
sampling_callback: SamplingFnT | None = None,
5965
list_roots_callback: ListRootsFnT | None = None,
6066
logging_callback: LoggingFnT | None = None,
67+
message_handler: MessageHandlerFnT | None = None,
6168
raise_exceptions: bool = False,
6269
) -> AsyncGenerator[ClientSession, None]:
6370
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
8794
sampling_callback=sampling_callback,
8895
list_roots_callback=list_roots_callback,
8996
logging_callback=logging_callback,
97+
message_handler=message_handler,
9098
) as client_session:
9199
await client_session.initialize()
92100
yield client_session

src/mcp/shared/session.py

+11-27
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,6 @@ def __init__(
189189
self._in_flight = {}
190190

191191
self._exit_stack = AsyncExitStack()
192-
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
193-
anyio.create_memory_object_stream[
194-
RequestResponder[ReceiveRequestT, SendResultT]
195-
| ReceiveNotificationT
196-
| Exception
197-
]()
198-
)
199-
self._exit_stack.push_async_callback(
200-
lambda: self._incoming_message_stream_reader.aclose()
201-
)
202-
self._exit_stack.push_async_callback(
203-
lambda: self._incoming_message_stream_writer.aclose()
204-
)
205192

206193
async def __aenter__(self) -> Self:
207194
self._task_group = anyio.create_task_group()
@@ -312,11 +299,10 @@ async def _receive_loop(self) -> None:
312299
async with (
313300
self._read_stream,
314301
self._write_stream,
315-
self._incoming_message_stream_writer,
316302
):
317303
async for message in self._read_stream:
318304
if isinstance(message, Exception):
319-
await self._incoming_message_stream_writer.send(message)
305+
await self._handle_incoming(message)
320306
elif isinstance(message.root, JSONRPCRequest):
321307
validated_request = self._receive_request_type.model_validate(
322308
message.root.model_dump(
@@ -336,8 +322,9 @@ async def _receive_loop(self) -> None:
336322

337323
self._in_flight[responder.request_id] = responder
338324
await self._received_request(responder)
325+
339326
if not responder._completed: # type: ignore[reportPrivateUsage]
340-
await self._incoming_message_stream_writer.send(responder)
327+
await self._handle_incoming(responder)
341328

342329
elif isinstance(message.root, JSONRPCNotification):
343330
try:
@@ -353,9 +340,7 @@ async def _receive_loop(self) -> None:
353340
await self._in_flight[cancelled_id].cancel()
354341
else:
355342
await self._received_notification(notification)
356-
await self._incoming_message_stream_writer.send(
357-
notification
358-
)
343+
await self._handle_incoming(notification)
359344
except Exception as e:
360345
# For other validation errors, log and continue
361346
logging.warning(
@@ -367,7 +352,7 @@ async def _receive_loop(self) -> None:
367352
if stream:
368353
await stream.send(message.root)
369354
else:
370-
await self._incoming_message_stream_writer.send(
355+
await self._handle_incoming(
371356
RuntimeError(
372357
"Received response with an unknown "
373358
f"request ID: {message}"
@@ -399,12 +384,11 @@ async def send_progress_notification(
399384
processed.
400385
"""
401386

402-
@property
403-
def incoming_messages(
387+
async def _handle_incoming(
404388
self,
405-
) -> MemoryObjectReceiveStream[
406-
RequestResponder[ReceiveRequestT, SendResultT]
389+
req: RequestResponder[ReceiveRequestT, SendResultT]
407390
| ReceiveNotificationT
408-
| Exception
409-
]:
410-
return self._incoming_message_stream_reader
391+
| Exception,
392+
) -> None:
393+
"""A generic handler for incoming messages. Overwritten by subclasses."""
394+
pass

tests/client/test_logging_callback.py

+34-36
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Literal
22

3-
import anyio
43
import pytest
54

5+
import mcp.types as types
66
from mcp.shared.memory import (
77
create_connected_server_and_client_session as create_session,
88
)
9+
from mcp.shared.session import RequestResponder
910
from mcp.types import (
1011
LoggingMessageNotificationParams,
1112
TextContent,
@@ -46,40 +47,37 @@ async def test_tool_with_log(
4647
)
4748
return True
4849

49-
async with anyio.create_task_group() as tg:
50-
async with create_session(
51-
server._mcp_server, logging_callback=logging_collector
52-
) as client_session:
50+
# Create a message handler to catch exceptions
51+
async def message_handler(
52+
message: RequestResponder[types.ServerRequest, types.ClientResult]
53+
| types.ServerNotification
54+
| Exception,
55+
) -> None:
56+
if isinstance(message, Exception):
57+
raise message
5358

54-
async def listen_session():
55-
try:
56-
async for message in client_session.incoming_messages:
57-
if isinstance(message, Exception):
58-
raise message
59-
except anyio.EndOfStream:
60-
pass
59+
async with create_session(
60+
server._mcp_server,
61+
logging_callback=logging_collector,
62+
message_handler=message_handler,
63+
) as client_session:
64+
# First verify our test tool works
65+
result = await client_session.call_tool("test_tool", {})
66+
assert result.isError is False
67+
assert isinstance(result.content[0], TextContent)
68+
assert result.content[0].text == "true"
6169

62-
tg.start_soon(listen_session)
63-
64-
# First verify our test tool works
65-
result = await client_session.call_tool("test_tool", {})
66-
assert result.isError is False
67-
assert isinstance(result.content[0], TextContent)
68-
assert result.content[0].text == "true"
69-
70-
# Now send a log message via our tool
71-
log_result = await client_session.call_tool(
72-
"test_tool_with_log",
73-
{
74-
"message": "Test log message",
75-
"level": "info",
76-
"logger": "test_logger",
77-
},
78-
)
79-
assert log_result.isError is False
80-
assert len(logging_collector.log_messages) == 1
81-
assert logging_collector.log_messages[
82-
0
83-
] == LoggingMessageNotificationParams(
84-
level="info", logger="test_logger", data="Test log message"
85-
)
70+
# Now send a log message via our tool
71+
log_result = await client_session.call_tool(
72+
"test_tool_with_log",
73+
{
74+
"message": "Test log message",
75+
"level": "info",
76+
"logger": "test_logger",
77+
},
78+
)
79+
assert log_result.isError is False
80+
assert len(logging_collector.log_messages) == 1
81+
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
82+
level="info", logger="test_logger", data="Test log message"
83+
)

0 commit comments

Comments
 (0)