Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #201: Move incoming message stream from BaseSession to ServerSession #325

Merged
merged 3 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions src/mcp/client/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.session import RequestResponder
from mcp.types import JSONRPCMessage

if not sys.warnoptions:
Expand All @@ -21,26 +23,25 @@
logger = logging.getLogger("client")


async def receive_loop(session: ClientSession):
logger.info("Starting receive loop")
async for message in session.incoming_messages:
if isinstance(message, Exception):
logger.error("Error: %s", message)
continue
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
logger.error("Error: %s", message)
return

logger.info("Received message from server: %s", message)
logger.info("Received message from server: %s", message)


async def run_session(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
):
async with (
ClientSession(read_stream, write_stream) as session,
anyio.create_task_group() as tg,
):
tg.start_soon(receive_loop, session)

async with ClientSession(
read_stream, write_stream, message_handler=message_handler
) as session:
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")
Expand Down
30 changes: 30 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any, Protocol

import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter

Expand Down Expand Up @@ -31,6 +32,23 @@ async def __call__(
) -> None: ...


class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None: ...


async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
await anyio.lowlevel.checkpoint()


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
Expand Down Expand Up @@ -78,6 +96,7 @@ def __init__(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -89,6 +108,7 @@ def __init__(
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
Expand Down Expand Up @@ -337,10 +357,20 @@ async def _received_request(
types.ClientResult(root=types.EmptyResult())
)

async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)

async def _received_notification(
self, notification: types.ServerNotification
) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
Expand Down
24 changes: 24 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class InitializationState(Enum):

ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

ServerRequestResponder = (
RequestResponder[types.ClientRequest, types.ServerResult]
| types.ClientNotification
| Exception
)


class ServerSession(
BaseSession[
Expand All @@ -85,6 +91,15 @@ def __init__(
)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[ServerRequestResponder](0)
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand Down Expand Up @@ -291,3 +306,12 @@ async def send_prompt_list_changed(self) -> None:
)
)
)

async def _handle_incoming(self, req: ServerRequestResponder) -> None:
await self._incoming_message_stream_writer.send(req)

@property
def incoming_messages(
self,
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
return self._incoming_message_stream_reader
10 changes: 9 additions & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server
from mcp.types import JSONRPCMessage

Expand Down Expand Up @@ -58,6 +64,7 @@ async def create_connected_server_and_client_session(
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
Expand Down Expand Up @@ -87,6 +94,7 @@ async def create_connected_server_and_client_session(
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
message_handler=message_handler,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
38 changes: 11 additions & 27 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,6 @@ def __init__(
self._in_flight = {}

self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
Expand Down Expand Up @@ -312,11 +299,10 @@ async def _receive_loop(self) -> None:
async with (
self._read_stream,
self._write_stream,
self._incoming_message_stream_writer,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._incoming_message_stream_writer.send(message)
await self._handle_incoming(message)
elif isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
Expand All @@ -336,8 +322,9 @@ async def _receive_loop(self) -> None:

self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._incoming_message_stream_writer.send(responder)
await self._handle_incoming(responder)

elif isinstance(message.root, JSONRPCNotification):
try:
Expand All @@ -353,9 +340,7 @@ async def _receive_loop(self) -> None:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
Expand All @@ -367,7 +352,7 @@ async def _receive_loop(self) -> None:
if stream:
await stream.send(message.root)
else:
await self._incoming_message_stream_writer.send(
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
Expand Down Expand Up @@ -399,12 +384,11 @@ async def send_progress_notification(
processed.
"""

@property
def incoming_messages(
async def _handle_incoming(
self,
) -> MemoryObjectReceiveStream[
RequestResponder[ReceiveRequestT, SendResultT]
req: RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]:
return self._incoming_message_stream_reader
| Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass
70 changes: 34 additions & 36 deletions tests/client/test_logging_callback.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Literal

import anyio
import pytest

import mcp.types as types
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.shared.session import RequestResponder
from mcp.types import (
LoggingMessageNotificationParams,
TextContent,
Expand Down Expand Up @@ -46,40 +47,37 @@ async def test_tool_with_log(
)
return True

async with anyio.create_task_group() as tg:
async with create_session(
server._mcp_server, logging_callback=logging_collector
) as client_session:
# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message

async def listen_session():
try:
async for message in client_session.incoming_messages:
if isinstance(message, Exception):
raise message
except anyio.EndOfStream:
pass
async with create_session(
server._mcp_server,
logging_callback=logging_collector,
message_handler=message_handler,
) as client_session:
# First verify our test tool works
result = await client_session.call_tool("test_tool", {})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

tg.start_soon(listen_session)

# First verify our test tool works
result = await client_session.call_tool("test_tool", {})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Now send a log message via our tool
log_result = await client_session.call_tool(
"test_tool_with_log",
{
"message": "Test log message",
"level": "info",
"logger": "test_logger",
},
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[
0
] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)
# Now send a log message via our tool
log_result = await client_session.call_tool(
"test_tool_with_log",
{
"message": "Test log message",
"level": "info",
"logger": "test_logger",
},
)
assert log_result.isError is False
assert len(logging_collector.log_messages) == 1
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
level="info", logger="test_logger", data="Test log message"
)
Loading