Skip to content

Commit 85514b4

Browse files
committed
Move incoming message stream from BaseSession to ServerSession
Fixes GitHub issue #201 by moving the incoming message stream and related methods from BaseSession to ServerSession where they are actually needed. This change follows the principle of having functionality only where it's required. GitHub-Issue:#201 🤖 Generated with [Claude Code](https://claude.ai/code)
1 parent 5a54d82 commit 85514b4

File tree

2 files changed

+34
-27
lines changed

2 files changed

+34
-27
lines changed

src/mcp/server/session.py

+24
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ class InitializationState(Enum):
6161

6262
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
6363

64+
ServerRequestResponder = (
65+
RequestResponder[types.ClientRequest, types.ServerResult]
66+
| types.ClientNotification
67+
| Exception
68+
)
69+
6470

6571
class ServerSession(
6672
BaseSession[
@@ -85,6 +91,15 @@ def __init__(
8591
)
8692
self._initialization_state = InitializationState.NotInitialized
8793
self._init_options = init_options
94+
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
95+
anyio.create_memory_object_stream[ServerRequestResponder]()
96+
)
97+
self._exit_stack.push_async_callback(
98+
lambda: self._incoming_message_stream_reader.aclose()
99+
)
100+
self._exit_stack.push_async_callback(
101+
lambda: self._incoming_message_stream_writer.aclose()
102+
)
88103

89104
@property
90105
def client_params(self) -> types.InitializeRequestParams | None:
@@ -291,3 +306,12 @@ async def send_prompt_list_changed(self) -> None:
291306
)
292307
)
293308
)
309+
310+
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
311+
return await self._incoming_message_stream_writer.send(req)
312+
313+
@property
314+
def incoming_messages(
315+
self,
316+
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
317+
return self._incoming_message_stream_reader

src/mcp/shared/session.py

+10-27
Original file line numberDiff line numberDiff line change
@@ -182,19 +182,6 @@ def __init__(
182182
self._in_flight = {}
183183

184184
self._exit_stack = AsyncExitStack()
185-
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
186-
anyio.create_memory_object_stream[
187-
RequestResponder[ReceiveRequestT, SendResultT]
188-
| ReceiveNotificationT
189-
| Exception
190-
]()
191-
)
192-
self._exit_stack.push_async_callback(
193-
lambda: self._incoming_message_stream_reader.aclose()
194-
)
195-
self._exit_stack.push_async_callback(
196-
lambda: self._incoming_message_stream_writer.aclose()
197-
)
198185

199186
async def __aenter__(self) -> Self:
200187
self._task_group = anyio.create_task_group()
@@ -300,11 +287,10 @@ async def _receive_loop(self) -> None:
300287
async with (
301288
self._read_stream,
302289
self._write_stream,
303-
self._incoming_message_stream_writer,
304290
):
305291
async for message in self._read_stream:
306292
if isinstance(message, Exception):
307-
await self._incoming_message_stream_writer.send(message)
293+
await self._handle_incoming(message)
308294
elif isinstance(message.root, JSONRPCRequest):
309295
validated_request = self._receive_request_type.model_validate(
310296
message.root.model_dump(
@@ -325,7 +311,7 @@ async def _receive_loop(self) -> None:
325311
self._in_flight[responder.request_id] = responder
326312
await self._received_request(responder)
327313
if not responder._completed:
328-
await self._incoming_message_stream_writer.send(responder)
314+
await self._handle_incoming(responder)
329315

330316
elif isinstance(message.root, JSONRPCNotification):
331317
try:
@@ -341,9 +327,7 @@ async def _receive_loop(self) -> None:
341327
await self._in_flight[cancelled_id].cancel()
342328
else:
343329
await self._received_notification(notification)
344-
await self._incoming_message_stream_writer.send(
345-
notification
346-
)
330+
await self._handle_incoming(notification)
347331
except Exception as e:
348332
# For other validation errors, log and continue
349333
logging.warning(
@@ -355,7 +339,7 @@ async def _receive_loop(self) -> None:
355339
if stream:
356340
await stream.send(message.root)
357341
else:
358-
await self._incoming_message_stream_writer.send(
342+
await self._handle_incoming(
359343
RuntimeError(
360344
"Received response with an unknown "
361345
f"request ID: {message}"
@@ -387,12 +371,11 @@ async def send_progress_notification(
387371
processed.
388372
"""
389373

390-
@property
391-
def incoming_messages(
374+
async def _handle_incoming(
392375
self,
393-
) -> MemoryObjectReceiveStream[
394-
RequestResponder[ReceiveRequestT, SendResultT]
376+
req: RequestResponder[ReceiveRequestT, SendResultT]
395377
| ReceiveNotificationT
396-
| Exception
397-
]:
398-
return self._incoming_message_stream_reader
378+
| Exception,
379+
) -> None:
380+
"""A generic handler for incoming messages. Overwritten by subclasses."""
381+
await anyio.lowlevel.checkpoint()

0 commit comments

Comments
 (0)