Skip to content

Commit 9d0f2da

Browse files
authored
refactor: reorganize message handling for better type safety and clarity (#239)
* refactor: improve typing with memory stream type aliases Move memory stream type definitions to models.py and use them throughout the codebase for better type safety and maintainability. GitHub-Issue:#201 * refactor: move streams to ParsedMessage * refactor: update test files to use ParsedMessage Updates test files to work with the ParsedMessage stream type aliases and fixes a line length issue in test_201_client_hangs_on_logging.py. Github-Issue:#201 * refactor: rename ParsedMessage to MessageFrame for clarity 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * refactor: move MessageFrame class to types.py for better code organization 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix pyright * refactor: update websocket client to use MessageFrame Modified the websocket client to work with the new MessageFrame type, preserving raw message text and properly extracting the root JSON-RPC message when sending. Github-Issue:#204 * fix: use NoneType instead of None for type parameters in MessageFrame 🤖 Generated with [Claude Code](https://claude.ai/code) * refactor: rename root to message
1 parent ad7f7a5 commit 9d0f2da

17 files changed

+283
-151
lines changed

src/mcp/client/session.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
54
from pydantic import AnyUrl, TypeAdapter
65

76
import mcp.types as types
87
from mcp.shared.context import RequestContext
9-
from mcp.shared.session import BaseSession, RequestResponder
8+
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
109
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1110

1211

@@ -59,8 +58,8 @@ class ClientSession(
5958
):
6059
def __init__(
6160
self,
62-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
63-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
61+
read_stream: ReadStream,
62+
write_stream: WriteStream,
6463
read_timeout_seconds: timedelta | None = None,
6564
sampling_callback: SamplingFnT | None = None,
6665
list_roots_callback: ListRootsFnT | None = None,

src/mcp/client/sse.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
import anyio
77
import httpx
88
from anyio.abc import TaskStatus
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
109
from httpx_sse import aconnect_sse
1110

1211
import mcp.types as types
12+
from mcp.shared.session import (
13+
ReadStream,
14+
ReadStreamWriter,
15+
WriteStream,
16+
WriteStreamReader,
17+
)
18+
from mcp.types import MessageFrame
1319

1420
logger = logging.getLogger(__name__)
1521

@@ -31,11 +37,11 @@ async def sse_client(
3137
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3238
event before disconnecting. All other HTTP operations are controlled by `timeout`.
3339
"""
34-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
35-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
40+
read_stream: ReadStream
41+
read_stream_writer: ReadStreamWriter
3642

37-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
38-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
43+
write_stream: WriteStream
44+
write_stream_reader: WriteStreamReader
3945

4046
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4147
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -84,8 +90,11 @@ async def sse_reader(
8490

8591
case "message":
8692
try:
87-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
88-
sse.data
93+
message = MessageFrame(
94+
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
95+
sse.data
96+
),
97+
raw=sse,
8998
)
9099
logger.debug(
91100
f"Received server message: {message}"

src/mcp/client/websocket.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from contextlib import asynccontextmanager
4-
from typing import AsyncGenerator
4+
from typing import Any, AsyncGenerator
55

66
import anyio
77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,6 +10,7 @@
1010
from websockets.typing import Subprotocol
1111

1212
import mcp.types as types
13+
from mcp.types import MessageFrame
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -19,8 +20,8 @@ async def websocket_client(
1920
url: str,
2021
) -> AsyncGenerator[
2122
tuple[
22-
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
23-
MemoryObjectSendStream[types.JSONRPCMessage],
23+
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
24+
MemoryObjectSendStream[MessageFrame[Any]],
2425
],
2526
None,
2627
]:
@@ -53,7 +54,11 @@ async def ws_reader():
5354
async with read_stream_writer:
5455
async for raw_text in ws:
5556
try:
56-
message = types.JSONRPCMessage.model_validate_json(raw_text)
57+
json_message = types.JSONRPCMessage.model_validate_json(
58+
raw_text
59+
)
60+
# Create MessageFrame with JSON message as root
61+
message = MessageFrame(message=json_message, raw=raw_text)
5762
await read_stream_writer.send(message)
5863
except ValidationError as exc:
5964
# If JSON parse or model validation fails, send the exception
@@ -66,8 +71,8 @@ async def ws_writer():
6671
"""
6772
async with write_stream_reader:
6873
async for message in write_stream_reader:
69-
# Convert to a dict, then to JSON
70-
msg_dict = message.model_dump(
74+
# Extract the JSON-RPC message from MessageFrame and convert to JSON
75+
msg_dict = message.message.model_dump(
7176
by_alias=True, mode="json", exclude_none=True
7277
)
7378
await ws.send(json.dumps(msg_dict))

src/mcp/server/lowlevel/server.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ async def main():
7474
from typing import Any, AsyncIterator, Generic, TypeVar
7575

7676
import anyio
77-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7877
from pydantic import AnyUrl
7978

8079
import mcp.types as types
@@ -84,7 +83,7 @@ async def main():
8483
from mcp.server.stdio import stdio_server as stdio_server
8584
from mcp.shared.context import RequestContext
8685
from mcp.shared.exceptions import McpError
87-
from mcp.shared.session import RequestResponder
86+
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
8887

8988
logger = logging.getLogger(__name__)
9089

@@ -474,8 +473,8 @@ async def handler(req: types.CompleteRequest):
474473

475474
async def run(
476475
self,
477-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
478-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
476+
read_stream: ReadStream,
477+
write_stream: WriteStream,
479478
initialization_options: InitializationOptions,
480479
# When False, exceptions are returned as messages to the client.
481480
# When True, exceptions are raised, which will cause the server to shut down

src/mcp/server/models.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
from pydantic import BaseModel
77

8-
from mcp.types import (
9-
ServerCapabilities,
10-
)
8+
from mcp.types import ServerCapabilities
119

1210

1311
class InitializationOptions(BaseModel):

src/mcp/server/session.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,15 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4242

4343
import anyio
4444
import anyio.lowlevel
45-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4645
from pydantic import AnyUrl
4746

4847
import mcp.types as types
4948
from mcp.server.models import InitializationOptions
5049
from mcp.shared.session import (
5150
BaseSession,
51+
ReadStream,
5252
RequestResponder,
53+
WriteStream,
5354
)
5455

5556

@@ -76,8 +77,8 @@ class ServerSession(
7677

7778
def __init__(
7879
self,
79-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
80-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
80+
read_stream: ReadStream,
81+
write_stream: WriteStream,
8182
init_options: InitializationOptions,
8283
) -> None:
8384
super().__init__(

src/mcp/server/sse.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,20 @@ async def handle_sse(request):
3838
from uuid import UUID, uuid4
3939

4040
import anyio
41-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4241
from pydantic import ValidationError
4342
from sse_starlette import EventSourceResponse
4443
from starlette.requests import Request
4544
from starlette.responses import Response
4645
from starlette.types import Receive, Scope, Send
4746

4847
import mcp.types as types
48+
from mcp.shared.session import (
49+
ReadStream,
50+
ReadStreamWriter,
51+
WriteStream,
52+
WriteStreamReader,
53+
)
54+
from mcp.types import MessageFrame
4955

5056
logger = logging.getLogger(__name__)
5157

@@ -63,9 +69,7 @@ class SseServerTransport:
6369
"""
6470

6571
_endpoint: str
66-
_read_stream_writers: dict[
67-
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
68-
]
72+
_read_stream_writers: dict[UUID, ReadStreamWriter]
6973

7074
def __init__(self, endpoint: str) -> None:
7175
"""
@@ -85,11 +89,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8589
raise ValueError("connect_sse can only handle HTTP requests")
8690

8791
logger.debug("Setting up SSE connection")
88-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
89-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
92+
read_stream: ReadStream
93+
read_stream_writer: ReadStreamWriter
9094

91-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
92-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
95+
write_stream: WriteStream
96+
write_stream_reader: WriteStreamReader
9397

9498
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9599
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -172,4 +176,4 @@ async def handle_post_message(
172176
logger.debug(f"Sending message to writer: {message}")
173177
response = Response("Accepted", status_code=202)
174178
await response(scope, receive, send)
175-
await writer.send(message)
179+
await writer.send(MessageFrame(message=message, raw=request))

src/mcp/server/stdio.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,15 @@ async def run_server():
2424

2525
import anyio
2626
import anyio.lowlevel
27-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2827

2928
import mcp.types as types
29+
from mcp.shared.session import (
30+
ReadStream,
31+
ReadStreamWriter,
32+
WriteStream,
33+
WriteStreamReader,
34+
)
35+
from mcp.types import MessageFrame
3036

3137

3238
@asynccontextmanager
@@ -47,11 +53,11 @@ async def stdio_server(
4753
if not stdout:
4854
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
4955

50-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
51-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
56+
read_stream: ReadStream
57+
read_stream_writer: ReadStreamWriter
5258

53-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
54-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
59+
write_stream: WriteStream
60+
write_stream_reader: WriteStreamReader
5561

5662
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
5763
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -66,14 +72,17 @@ async def stdin_reader():
6672
await read_stream_writer.send(exc)
6773
continue
6874

69-
await read_stream_writer.send(message)
75+
await read_stream_writer.send(
76+
MessageFrame(message=message, raw=line)
77+
)
7078
except anyio.ClosedResourceError:
7179
await anyio.lowlevel.checkpoint()
7280

7381
async def stdout_writer():
7482
try:
7583
async with write_stream_reader:
7684
async for message in write_stream_reader:
85+
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
7786
json = message.model_dump_json(by_alias=True, exclude_none=True)
7887
await stdout.write(json + "\n")
7988
await stdout.flush()

src/mcp/server/websocket.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
from contextlib import asynccontextmanager
33

44
import anyio
5-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
65
from starlette.types import Receive, Scope, Send
76
from starlette.websockets import WebSocket
87

98
import mcp.types as types
9+
from mcp.shared.session import (
10+
ReadStream,
11+
ReadStreamWriter,
12+
WriteStream,
13+
WriteStreamReader,
14+
)
15+
from mcp.types import MessageFrame
1016

1117
logger = logging.getLogger(__name__)
1218

@@ -21,11 +27,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
2127
websocket = WebSocket(scope, receive, send)
2228
await websocket.accept(subprotocol="mcp")
2329

24-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
25-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
30+
read_stream: ReadStream
31+
read_stream_writer: ReadStreamWriter
2632

27-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
28-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
33+
write_stream: WriteStream
34+
write_stream_reader: WriteStreamReader
2935

3036
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
3137
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -40,7 +46,9 @@ async def ws_reader():
4046
await read_stream_writer.send(exc)
4147
continue
4248

43-
await read_stream_writer.send(client_message)
49+
await read_stream_writer.send(
50+
MessageFrame(message=client_message, raw=message)
51+
)
4452
except anyio.ClosedResourceError:
4553
await websocket.close()
4654

src/mcp/shared/memory.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
14-
from mcp.types import JSONRPCMessage
14+
from mcp.types import MessageFrame
1515

1616
MessageStream = tuple[
17-
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
18-
MemoryObjectSendStream[JSONRPCMessage],
17+
MemoryObjectReceiveStream[MessageFrame | Exception],
18+
MemoryObjectSendStream[MessageFrame],
1919
]
2020

2121

@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
3232
"""
3333
# Create streams for both directions
3434
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
35-
JSONRPCMessage | Exception
35+
MessageFrame | Exception
3636
](1)
3737
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
38-
JSONRPCMessage | Exception
38+
MessageFrame | Exception
3939
](1)
4040

4141
client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,12 +60,9 @@ async def create_connected_server_and_client_session(
6060
) -> AsyncGenerator[ClientSession, None]:
6161
"""Creates a ClientSession that is connected to a running MCP server."""
6262
async with create_client_server_memory_streams() as (
63-
client_streams,
64-
server_streams,
63+
(client_read, client_write),
64+
(server_read, server_write),
6565
):
66-
client_read, client_write = client_streams
67-
server_read, server_write = server_streams
68-
6966
# Create a cancel scope for the server task
7067
async with anyio.create_task_group() as tg:
7168
tg.start_soon(

0 commit comments

Comments
 (0)