diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 66bf206e..cde3103b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,11 +1,12 @@ from datetime import timedelta from typing import Any, Protocol +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl, TypeAdapter import mcp.types as types from mcp.shared.context import RequestContext -from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream +from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -58,8 +59,8 @@ class ClientSession( ): def __init__( self, - read_stream: ReadStream, - write_stream: WriteStream, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 0f3039b5..abafacb9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -6,16 +6,10 @@ import anyio import httpx from anyio.abc import TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse import mcp.types as types -from mcp.shared.session import ( - ReadStream, - ReadStreamWriter, - WriteStream, - WriteStreamReader, -) -from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -37,11 +31,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: ReadStream - read_stream_writer: ReadStreamWriter + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: WriteStream - write_stream_reader: WriteStreamReader + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -90,11 +84,8 @@ async def sse_reader( case "message": try: - message = MessageFrame( - message=types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data - ), - raw=sse, + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data ) logger.debug( f"Received server message: {message}" diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index f2107d6b..3e73b020 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,7 +1,7 @@ import json import logging from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator +from typing import AsyncGenerator import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -10,7 +10,6 @@ from websockets.typing import Subprotocol import mcp.types as types -from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -20,8 +19,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[MessageFrame[Any] | Exception], - MemoryObjectSendStream[MessageFrame[Any]], + MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + MemoryObjectSendStream[types.JSONRPCMessage], ], None, ]: @@ -54,11 +53,7 @@ async def ws_reader(): async with read_stream_writer: async for raw_text in ws: try: - json_message = types.JSONRPCMessage.model_validate_json( - raw_text - ) - # Create MessageFrame with JSON message as root - message = MessageFrame(message=json_message, raw=raw_text) + message = types.JSONRPCMessage.model_validate_json(raw_text) await read_stream_writer.send(message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception @@ -71,8 +66,8 @@ async def ws_writer(): """ async with write_stream_reader: async for message in write_stream_reader: - # Extract the JSON-RPC message from MessageFrame and convert to JSON - msg_dict = message.message.model_dump( + # Convert to a dict, then to JSON + msg_dict = message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 7ceb103e..817d1918 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -74,6 +74,7 @@ async def main(): from typing import Any, AsyncIterator, Generic, TypeVar import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types @@ -83,7 +84,7 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.session import ReadStream, RequestResponder, WriteStream +from mcp.shared.session import RequestResponder logger = logging.getLogger(__name__) @@ -473,8 +474,8 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: ReadStream, - write_stream: WriteStream, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 58a2db1d..3b5abba7 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -5,7 +5,9 @@ from pydantic import BaseModel -from mcp.types import ServerCapabilities +from mcp.types import ( + ServerCapabilities, +) class InitializationOptions(BaseModel): diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index c22dcf87..788bb9f8 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -42,15 +42,14 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl import mcp.types as types from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, - ReadStream, RequestResponder, - WriteStream, ) @@ -77,8 +76,8 @@ class ServerSession( def __init__( self, - read_stream: ReadStream, - write_stream: WriteStream, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, ) -> None: super().__init__( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 1e869685..0127753d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -38,6 +38,7 @@ async def handle_sse(request): from uuid import UUID, uuid4 import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse from starlette.requests import Request @@ -45,13 +46,6 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types -from mcp.shared.session import ( - ReadStream, - ReadStreamWriter, - WriteStream, - WriteStreamReader, -) -from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -69,7 +63,9 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, ReadStreamWriter] + _read_stream_writers: dict[ + UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] + ] def __init__(self, endpoint: str) -> None: """ @@ -89,11 +85,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: ReadStream - read_stream_writer: ReadStreamWriter + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: WriteStream - write_stream_reader: WriteStreamReader + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -176,4 +172,4 @@ async def handle_post_message( logger.debug(f"Sending message to writer: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(MessageFrame(message=message, raw=request)) + await writer.send(message) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 91819a7d..0e0e4912 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -24,15 +24,9 @@ async def run_server(): import anyio import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.shared.session import ( - ReadStream, - ReadStreamWriter, - WriteStream, - WriteStreamReader, -) -from mcp.types import MessageFrame @asynccontextmanager @@ -53,11 +47,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: ReadStream - read_stream_writer: ReadStreamWriter + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: WriteStream - write_stream_reader: WriteStreamReader + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -72,9 +66,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send( - MessageFrame(message=message, raw=line) - ) + await read_stream_writer.send(message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -82,7 +74,6 @@ async def stdout_writer(): try: async with write_stream_reader: async for message in write_stream_reader: - # Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame json = message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 2da93634..bd3d632e 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -2,17 +2,11 @@ from contextlib import asynccontextmanager import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket import mcp.types as types -from mcp.shared.session import ( - ReadStream, - ReadStreamWriter, - WriteStream, - WriteStreamReader, -) -from mcp.types import MessageFrame logger = logging.getLogger(__name__) @@ -27,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: ReadStream - read_stream_writer: ReadStreamWriter + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: WriteStream - write_stream_reader: WriteStreamReader + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -46,9 +40,7 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send( - MessageFrame(message=client_message, raw=message) - ) + await read_stream_writer.send(client_message) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 762ff28a..ae6b0be5 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -11,11 +11,11 @@ from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server -from mcp.types import MessageFrame +from mcp.types import JSONRPCMessage MessageStream = tuple[ - MemoryObjectReceiveStream[MessageFrame | Exception], - MemoryObjectSendStream[MessageFrame], + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], ] @@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - MessageFrame | Exception + JSONRPCMessage | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - MessageFrame | Exception + JSONRPCMessage | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) @@ -60,9 +60,12 @@ async def create_connected_server_and_client_session( ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( - (client_read, client_write), - (server_read, server_write), + client_streams, + server_streams, ): + client_read, client_write = client_streams + server_read, server_write = server_streams + # Create a cancel scope for the server task async with anyio.create_task_group() as tg: tg.start_soon( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 7dd6fefc..31f88824 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -22,18 +22,12 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - MessageFrame, RequestParams, ServerNotification, ServerRequest, ServerResult, ) -ReadStream = MemoryObjectReceiveStream[MessageFrame | Exception] -ReadStreamWriter = MemoryObjectSendStream[MessageFrame | Exception] -WriteStream = MemoryObjectSendStream[MessageFrame] -WriteStreamReader = MemoryObjectReceiveStream[MessageFrame] - SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) @@ -171,8 +165,8 @@ class BaseSession( def __init__( self, - read_stream: ReadStream, - write_stream: WriteStream, + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -248,9 +242,7 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send( - MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None) - ) + await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) try: with anyio.fail_after( @@ -286,18 +278,14 @@ async def send_notification(self, notification: SendNotificationT) -> None: **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send( - MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None) - ) + await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send( - MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None) - ) + await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -306,9 +294,7 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send( - MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None) - ) + await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) async def _receive_loop(self) -> None: async with ( @@ -316,13 +302,10 @@ async def _receive_loop(self) -> None: self._write_stream, self._incoming_message_stream_writer, ): - async for raw_message in self._read_stream: - if isinstance(raw_message, Exception): - await self._incoming_message_stream_writer.send(raw_message) - continue - - message = raw_message.message - if isinstance(message.root, JSONRPCRequest): + async for message in self._read_stream: + if isinstance(message, Exception): + await self._incoming_message_stream_writer.send(message) + elif isinstance(message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( message.root.model_dump( by_alias=True, mode="json", exclude_none=True diff --git a/src/mcp/types.py b/src/mcp/types.py index 38384dea..7d867bd3 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -180,49 +180,6 @@ class JSONRPCMessage( pass -RawT = TypeVar("RawT") - - -class MessageFrame(BaseModel, Generic[RawT]): - """ - A wrapper around the general message received that contains both the parsed message - and the raw message. - - This class serves as an encapsulation for JSON-RPC messages, providing access to - both the parsed structure (root) and the original raw data. This design is - particularly useful for Server-Sent Events (SSE) consumers who may need to access - additional metadata or headers associated with the message. - - The 'root' attribute contains the parsed JSONRPCMessage, which could be a request, - notification, response, or error. The 'raw' attribute preserves the original - message as received, allowing access to any additional context or metadata that - might be lost in parsing. - - This dual representation allows for flexible handling of messages, where consumers - can work with the structured data for standard operations, but still have the - option to examine or utilize the raw data when needed, such as for debugging, - logging, or accessing transport-specific information. - """ - - message: JSONRPCMessage - raw: RawT | None = None - model_config = ConfigDict(extra="allow") - - def model_dump(self, *args, **kwargs): - """ - Dumps the model to a dictionary, delegating to the root JSONRPCMessage. - This method allows for consistent serialization of the parsed message. - """ - return self.message.model_dump(*args, **kwargs) - - def model_dump_json(self, *args, **kwargs): - """ - Dumps the model to a JSON string, delegating to the root JSONRPCMessage. - This method provides a convenient way to serialize the parsed message to JSON. - """ - return self.message.model_dump_json(*args, **kwargs) - - class EmptyResult(Result): """A response that indicates success but carries no data.""" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 27f02abf..7d579cda 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,5 +1,3 @@ -from types import NoneType - import anyio import pytest @@ -13,9 +11,9 @@ InitializeRequest, InitializeResult, JSONRPCMessage, + JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - MessageFrame, ServerCapabilities, ServerResult, ) @@ -24,10 +22,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - MessageFrame[NoneType] + JSONRPCMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - MessageFrame[NoneType] + JSONRPCMessage ](1) initialized_notification = None @@ -36,7 +34,7 @@ async def mock_server(): nonlocal initialized_notification jsonrpc_request = await client_to_server_receive.receive() - assert isinstance(jsonrpc_request, MessageFrame) + assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -58,25 +56,21 @@ async def mock_server(): ) async with server_to_client_send: - assert isinstance(jsonrpc_request.message.root, JSONRPCRequest) await server_to_client_send.send( - MessageFrame( - message=JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.message.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - ) - ), - raw=None, + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) jsonrpc_notification = await client_to_server_receive.receive() - assert isinstance(jsonrpc_notification.message, JSONRPCMessage) + assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.message.model_dump( + jsonrpc_notification.model_dump( by_alias=True, mode="json", exclude_none=True ) ) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index fd05c773..00e18789 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -11,7 +11,6 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, - MessageFrame, NotificationParams, ) @@ -65,9 +64,7 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send( - MessageFrame(message=JSONRPCMessage(root=init_req), raw=None) - ) + await client_writer.send(JSONRPCMessage(root=init_req)) await server_reader.receive() # Get init response but don't need to check it # Send initialized notification @@ -76,27 +73,21 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send( - MessageFrame( - message=JSONRPCMessage(root=initialized_notification), raw=None - ) - ) + await client_writer.send(JSONRPCMessage(root=initialized_notification)) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send( - MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None) - ) + await client_writer.send(JSONRPCMessage(root=ping_request)) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.message.root.id == custom_request_id + response.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 18d9a4c5..37a52969 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -17,7 +17,6 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, - MessageFrame, ) @@ -65,7 +64,7 @@ async def run_server(): send_stream2, InitializationOptions( server_name="test", - server_version="1.0.0", + server_version="0.1.0", capabilities=server.get_capabilities( notification_options=NotificationOptions(), experimental_capabilities={}, @@ -83,51 +82,42 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) # Get response and verify response = await receive_stream2.receive() - assert response.message.root.result["content"][0]["text"] == "true" + assert response.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -188,51 +178,42 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) response = await receive_stream2.receive() # Send initialized notification await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) - ), - raw=None, + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) # Get response and verify response = await receive_stream2.receive() - assert response.message.root.result["content"][0]["text"] == "true" + assert response.root.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index a28fda7f..333196c9 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -9,7 +9,7 @@ from mcp.types import ( ClientNotification, InitializedNotification, - MessageFrame, + JSONRPCMessage, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -19,10 +19,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - MessageFrame[None] + JSONRPCMessage ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - MessageFrame[None] + JSONRPCMessage ](1) async def run_client(client: ClientSession): diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index c12c2637..85c5bf21 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,7 +4,7 @@ import pytest from mcp.server.stdio import stdio_server -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, MessageFrame +from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @pytest.mark.anyio @@ -13,8 +13,8 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), - JSONRPCResponse(jsonrpc="2.0", id=2, result={}), + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), ] for message in messages: @@ -35,29 +35,17 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert isinstance(received_messages[0].message, JSONRPCMessage) - assert isinstance(received_messages[0].message.root, JSONRPCRequest) - assert received_messages[0].message.root.id == 1 - assert received_messages[0].message.root.method == "ping" - - assert isinstance(received_messages[1].message, JSONRPCMessage) - assert isinstance(received_messages[1].message.root, JSONRPCResponse) - assert received_messages[1].message.root.id == 2 + assert received_messages[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + ) + assert received_messages[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + ) # Test sending responses from the server responses = [ - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") - ), - raw=None, - ), - MessageFrame( - message=JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) - ), - raw=None, - ), + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), ] async with write_stream: @@ -68,10 +56,13 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - # Parse and verify the JSON responses directly - request_json = JSONRPCRequest.model_validate_json(output_lines[0].strip()) - response_json = JSONRPCResponse.model_validate_json(output_lines[1].strip()) - - assert request_json.id == 3 - assert request_json.method == "ping" - assert response_json.id == 4 + received_responses = [ + JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines + ] + assert len(received_responses) == 2 + assert received_responses[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + ) + assert received_responses[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) + )