Skip to content

Commit 7196604

Browse files
authored
Revert "refactor: reorganize message handling for better type safety and clar…" (#282)
This reverts commit 9d0f2da.
1 parent ebb81d3 commit 7196604

17 files changed

+151
-283
lines changed

src/mcp/client/session.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
45
from pydantic import AnyUrl, TypeAdapter
56

67
import mcp.types as types
78
from mcp.shared.context import RequestContext
8-
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
9+
from mcp.shared.session import BaseSession, RequestResponder
910
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1011

1112

@@ -58,8 +59,8 @@ class ClientSession(
5859
):
5960
def __init__(
6061
self,
61-
read_stream: ReadStream,
62-
write_stream: WriteStream,
62+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
63+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
6364
read_timeout_seconds: timedelta | None = None,
6465
sampling_callback: SamplingFnT | None = None,
6566
list_roots_callback: ListRootsFnT | None = None,

src/mcp/client/sse.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,10 @@
66
import anyio
77
import httpx
88
from anyio.abc import TaskStatus
9+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
910
from httpx_sse import aconnect_sse
1011

1112
import mcp.types as types
12-
from mcp.shared.session import (
13-
ReadStream,
14-
ReadStreamWriter,
15-
WriteStream,
16-
WriteStreamReader,
17-
)
18-
from mcp.types import MessageFrame
1913

2014
logger = logging.getLogger(__name__)
2115

@@ -37,11 +31,11 @@ async def sse_client(
3731
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3832
event before disconnecting. All other HTTP operations are controlled by `timeout`.
3933
"""
40-
read_stream: ReadStream
41-
read_stream_writer: ReadStreamWriter
34+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
35+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
4236

43-
write_stream: WriteStream
44-
write_stream_reader: WriteStreamReader
37+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
38+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
4539

4640
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4741
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -90,11 +84,8 @@ async def sse_reader(
9084

9185
case "message":
9286
try:
93-
message = MessageFrame(
94-
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
95-
sse.data
96-
),
97-
raw=sse,
87+
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
88+
sse.data
9889
)
9990
logger.debug(
10091
f"Received server message: {message}"

src/mcp/client/websocket.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from contextlib import asynccontextmanager
4-
from typing import Any, AsyncGenerator
4+
from typing import AsyncGenerator
55

66
import anyio
77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -10,7 +10,6 @@
1010
from websockets.typing import Subprotocol
1111

1212
import mcp.types as types
13-
from mcp.types import MessageFrame
1413

1514
logger = logging.getLogger(__name__)
1615

@@ -20,8 +19,8 @@ async def websocket_client(
2019
url: str,
2120
) -> AsyncGenerator[
2221
tuple[
23-
MemoryObjectReceiveStream[MessageFrame[Any] | Exception],
24-
MemoryObjectSendStream[MessageFrame[Any]],
22+
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
23+
MemoryObjectSendStream[types.JSONRPCMessage],
2524
],
2625
None,
2726
]:
@@ -54,11 +53,7 @@ async def ws_reader():
5453
async with read_stream_writer:
5554
async for raw_text in ws:
5655
try:
57-
json_message = types.JSONRPCMessage.model_validate_json(
58-
raw_text
59-
)
60-
# Create MessageFrame with JSON message as root
61-
message = MessageFrame(message=json_message, raw=raw_text)
56+
message = types.JSONRPCMessage.model_validate_json(raw_text)
6257
await read_stream_writer.send(message)
6358
except ValidationError as exc:
6459
# If JSON parse or model validation fails, send the exception
@@ -71,8 +66,8 @@ async def ws_writer():
7166
"""
7267
async with write_stream_reader:
7368
async for message in write_stream_reader:
74-
# Extract the JSON-RPC message from MessageFrame and convert to JSON
75-
msg_dict = message.message.model_dump(
69+
# Convert to a dict, then to JSON
70+
msg_dict = message.model_dump(
7671
by_alias=True, mode="json", exclude_none=True
7772
)
7873
await ws.send(json.dumps(msg_dict))

src/mcp/server/lowlevel/server.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ async def main():
7474
from typing import Any, AsyncIterator, Generic, TypeVar
7575

7676
import anyio
77+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7778
from pydantic import AnyUrl
7879

7980
import mcp.types as types
@@ -83,7 +84,7 @@ async def main():
8384
from mcp.server.stdio import stdio_server as stdio_server
8485
from mcp.shared.context import RequestContext
8586
from mcp.shared.exceptions import McpError
86-
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
87+
from mcp.shared.session import RequestResponder
8788

8889
logger = logging.getLogger(__name__)
8990

@@ -473,8 +474,8 @@ async def handler(req: types.CompleteRequest):
473474

474475
async def run(
475476
self,
476-
read_stream: ReadStream,
477-
write_stream: WriteStream,
477+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
478+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
478479
initialization_options: InitializationOptions,
479480
# When False, exceptions are returned as messages to the client.
480481
# When True, exceptions are raised, which will cause the server to shut down

src/mcp/server/models.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from pydantic import BaseModel
77

8-
from mcp.types import ServerCapabilities
8+
from mcp.types import (
9+
ServerCapabilities,
10+
)
911

1012

1113
class InitializationOptions(BaseModel):

src/mcp/server/session.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,14 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4242

4343
import anyio
4444
import anyio.lowlevel
45+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4546
from pydantic import AnyUrl
4647

4748
import mcp.types as types
4849
from mcp.server.models import InitializationOptions
4950
from mcp.shared.session import (
5051
BaseSession,
51-
ReadStream,
5252
RequestResponder,
53-
WriteStream,
5453
)
5554

5655

@@ -77,8 +76,8 @@ class ServerSession(
7776

7877
def __init__(
7978
self,
80-
read_stream: ReadStream,
81-
write_stream: WriteStream,
79+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
80+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
8281
init_options: InitializationOptions,
8382
) -> None:
8483
super().__init__(

src/mcp/server/sse.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,14 @@ async def handle_sse(request):
3838
from uuid import UUID, uuid4
3939

4040
import anyio
41+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4142
from pydantic import ValidationError
4243
from sse_starlette import EventSourceResponse
4344
from starlette.requests import Request
4445
from starlette.responses import Response
4546
from starlette.types import Receive, Scope, Send
4647

4748
import mcp.types as types
48-
from mcp.shared.session import (
49-
ReadStream,
50-
ReadStreamWriter,
51-
WriteStream,
52-
WriteStreamReader,
53-
)
54-
from mcp.types import MessageFrame
5549

5650
logger = logging.getLogger(__name__)
5751

@@ -69,7 +63,9 @@ class SseServerTransport:
6963
"""
7064

7165
_endpoint: str
72-
_read_stream_writers: dict[UUID, ReadStreamWriter]
66+
_read_stream_writers: dict[
67+
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
68+
]
7369

7470
def __init__(self, endpoint: str) -> None:
7571
"""
@@ -89,11 +85,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8985
raise ValueError("connect_sse can only handle HTTP requests")
9086

9187
logger.debug("Setting up SSE connection")
92-
read_stream: ReadStream
93-
read_stream_writer: ReadStreamWriter
88+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
89+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
9490

95-
write_stream: WriteStream
96-
write_stream_reader: WriteStreamReader
91+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
92+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
9793

9894
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9995
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -176,4 +172,4 @@ async def handle_post_message(
176172
logger.debug(f"Sending message to writer: {message}")
177173
response = Response("Accepted", status_code=202)
178174
await response(scope, receive, send)
179-
await writer.send(MessageFrame(message=message, raw=request))
175+
await writer.send(message)

src/mcp/server/stdio.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,9 @@ async def run_server():
2424

2525
import anyio
2626
import anyio.lowlevel
27+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2728

2829
import mcp.types as types
29-
from mcp.shared.session import (
30-
ReadStream,
31-
ReadStreamWriter,
32-
WriteStream,
33-
WriteStreamReader,
34-
)
35-
from mcp.types import MessageFrame
3630

3731

3832
@asynccontextmanager
@@ -53,11 +47,11 @@ async def stdio_server(
5347
if not stdout:
5448
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
5549

56-
read_stream: ReadStream
57-
read_stream_writer: ReadStreamWriter
50+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
51+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
5852

59-
write_stream: WriteStream
60-
write_stream_reader: WriteStreamReader
53+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
54+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
6155

6256
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
6357
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -72,17 +66,14 @@ async def stdin_reader():
7266
await read_stream_writer.send(exc)
7367
continue
7468

75-
await read_stream_writer.send(
76-
MessageFrame(message=message, raw=line)
77-
)
69+
await read_stream_writer.send(message)
7870
except anyio.ClosedResourceError:
7971
await anyio.lowlevel.checkpoint()
8072

8173
async def stdout_writer():
8274
try:
8375
async with write_stream_reader:
8476
async for message in write_stream_reader:
85-
# Extract the inner JSONRPCRequest/JSONRPCResponse from MessageFrame
8677
json = message.model_dump_json(by_alias=True, exclude_none=True)
8778
await stdout.write(json + "\n")
8879
await stdout.flush()

src/mcp/server/websocket.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,11 @@
22
from contextlib import asynccontextmanager
33

44
import anyio
5+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
56
from starlette.types import Receive, Scope, Send
67
from starlette.websockets import WebSocket
78

89
import mcp.types as types
9-
from mcp.shared.session import (
10-
ReadStream,
11-
ReadStreamWriter,
12-
WriteStream,
13-
WriteStreamReader,
14-
)
15-
from mcp.types import MessageFrame
1610

1711
logger = logging.getLogger(__name__)
1812

@@ -27,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
2721
websocket = WebSocket(scope, receive, send)
2822
await websocket.accept(subprotocol="mcp")
2923

30-
read_stream: ReadStream
31-
read_stream_writer: ReadStreamWriter
24+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
25+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
3226

33-
write_stream: WriteStream
34-
write_stream_reader: WriteStreamReader
27+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
28+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
3529

3630
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
3731
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -46,9 +40,7 @@ async def ws_reader():
4640
await read_stream_writer.send(exc)
4741
continue
4842

49-
await read_stream_writer.send(
50-
MessageFrame(message=client_message, raw=message)
51-
)
43+
await read_stream_writer.send(client_message)
5244
except anyio.ClosedResourceError:
5345
await websocket.close()
5446

src/mcp/shared/memory.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
14-
from mcp.types import MessageFrame
14+
from mcp.types import JSONRPCMessage
1515

1616
MessageStream = tuple[
17-
MemoryObjectReceiveStream[MessageFrame | Exception],
18-
MemoryObjectSendStream[MessageFrame],
17+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
18+
MemoryObjectSendStream[JSONRPCMessage],
1919
]
2020

2121

@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
3232
"""
3333
# Create streams for both directions
3434
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
35-
MessageFrame | Exception
35+
JSONRPCMessage | Exception
3636
](1)
3737
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
38-
MessageFrame | Exception
38+
JSONRPCMessage | Exception
3939
](1)
4040

4141
client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,9 +60,12 @@ async def create_connected_server_and_client_session(
6060
) -> AsyncGenerator[ClientSession, None]:
6161
"""Creates a ClientSession that is connected to a running MCP server."""
6262
async with create_client_server_memory_streams() as (
63-
(client_read, client_write),
64-
(server_read, server_write),
63+
client_streams,
64+
server_streams,
6565
):
66+
client_read, client_write = client_streams
67+
server_read, server_write = server_streams
68+
6669
# Create a cancel scope for the server task
6770
async with anyio.create_task_group() as tg:
6871
tg.start_soon(

0 commit comments

Comments
 (0)