Skip to content

Commit 3af37d6

Browse files
committed
refactor: move streams to ParsedMessage
1 parent 8ea8bf5 commit 3af37d6

File tree

10 files changed

+79
-48
lines changed

10 files changed

+79
-48
lines changed

src/mcp/client/session.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
54
from pydantic import AnyUrl, TypeAdapter
65

76
import mcp.types as types
87
from mcp.shared.context import RequestContext
9-
from mcp.shared.session import BaseSession, RequestResponder
8+
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
109
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1110

1211

@@ -59,8 +58,8 @@ class ClientSession(
5958
):
6059
def __init__(
6160
self,
62-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
63-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
61+
read_stream: ReadStream,
62+
write_stream: WriteStream,
6463
read_timeout_seconds: timedelta | None = None,
6564
sampling_callback: SamplingFnT | None = None,
6665
list_roots_callback: ListRootsFnT | None = None,

src/mcp/client/sse.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
import anyio
77
import httpx
88
from anyio.abc import TaskStatus
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
109
from httpx_sse import aconnect_sse
1110

1211
import mcp.types as types
12+
from mcp.shared.session import (
13+
ParsedMessage,
14+
ReadStream,
15+
ReadStreamWriter,
16+
WriteStream,
17+
WriteStreamReader,
18+
)
1319

1420
logger = logging.getLogger(__name__)
1521

@@ -31,11 +37,11 @@ async def sse_client(
3137
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3238
event before disconnecting. All other HTTP operations are controlled by `timeout`.
3339
"""
34-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
35-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
40+
read_stream: ReadStream
41+
read_stream_writer: ReadStreamWriter
3642

37-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
38-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
43+
write_stream: WriteStream
44+
write_stream_reader: WriteStreamReader
3945

4046
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4147
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -84,8 +90,11 @@ async def sse_reader(
8490

8591
case "message":
8692
try:
87-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
88-
sse.data
93+
message = ParsedMessage(
94+
types.JSONRPCMessage.model_validate_json( # noqa: E501
95+
sse.data
96+
),
97+
raw=sse,
8998
)
9099
logger.debug(
91100
f"Received server message: {message}"

src/mcp/server/lowlevel/server.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ async def main():
7878

7979
import mcp.types as types
8080
from mcp.server.lowlevel.helper_types import ReadResourceContents
81-
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
81+
from mcp.server.models import InitializationOptions
8282
from mcp.server.session import ServerSession
8383
from mcp.server.stdio import stdio_server as stdio_server
8484
from mcp.shared.context import RequestContext
8585
from mcp.shared.exceptions import McpError
86-
from mcp.shared.session import RequestResponder
86+
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
8787

8888
logger = logging.getLogger(__name__)
8989

src/mcp/server/models.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,9 @@
33
and tools.
44
"""
55

6-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
76
from pydantic import BaseModel
87

9-
from mcp.types import JSONRPCMessage, ServerCapabilities
10-
11-
ReadStream = MemoryObjectReceiveStream[JSONRPCMessage | Exception]
12-
ReadStreamWriter = MemoryObjectSendStream[JSONRPCMessage | Exception]
13-
WriteStream = MemoryObjectSendStream[JSONRPCMessage]
14-
WriteStreamReader = MemoryObjectReceiveStream[JSONRPCMessage]
8+
from mcp.types import ServerCapabilities
159

1610

1711
class InitializationOptions(BaseModel):

src/mcp/server/session.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4545
from pydantic import AnyUrl
4646

4747
import mcp.types as types
48-
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
48+
from mcp.server.models import InitializationOptions
4949
from mcp.shared.session import (
5050
BaseSession,
51+
ReadStream,
5152
RequestResponder,
53+
WriteStream,
5254
)
5355

5456

src/mcp/server/sse.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ async def handle_sse(request):
4545
from starlette.types import Receive, Scope, Send
4646

4747
import mcp.types as types
48-
from mcp.server.models import (
48+
from mcp.shared.session import (
49+
ParsedMessage,
4950
ReadStream,
5051
ReadStreamWriter,
5152
WriteStream,
@@ -175,4 +176,4 @@ async def handle_post_message(
175176
logger.debug(f"Sending message to writer: {message}")
176177
response = Response("Accepted", status_code=202)
177178
await response(scope, receive, send)
178-
await writer.send(message)
179+
await writer.send(ParsedMessage(message, raw=request))

src/mcp/server/stdio.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ async def run_server():
2626
import anyio.lowlevel
2727

2828
import mcp.types as types
29-
from mcp.server.models import (
29+
from mcp.shared.session import (
30+
ParsedMessage,
3031
ReadStream,
3132
ReadStreamWriter,
3233
WriteStream,
@@ -71,7 +72,7 @@ async def stdin_reader():
7172
await read_stream_writer.send(exc)
7273
continue
7374

74-
await read_stream_writer.send(message)
75+
await read_stream_writer.send(ParsedMessage(message, raw=line))
7576
except anyio.ClosedResourceError:
7677
await anyio.lowlevel.checkpoint()
7778

src/mcp/server/websocket.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from starlette.websockets import WebSocket
77

88
import mcp.types as types
9-
from mcp.server.models import (
9+
from mcp.shared.session import (
10+
ParsedMessage,
1011
ReadStream,
1112
ReadStreamWriter,
1213
WriteStream,
@@ -45,7 +46,9 @@ async def ws_reader():
4546
await read_stream_writer.send(exc)
4647
continue
4748

48-
await read_stream_writer.send(client_message)
49+
await read_stream_writer.send(
50+
ParsedMessage(client_message, raw=message)
51+
)
4952
except anyio.ClosedResourceError:
5053
await websocket.close()
5154

src/mcp/shared/memory.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
14-
from mcp.types import JSONRPCMessage
14+
from mcp.shared.session import ParsedMessage
1515

1616
MessageStream = tuple[
17-
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
18-
MemoryObjectSendStream[JSONRPCMessage],
17+
MemoryObjectReceiveStream[ParsedMessage | Exception],
18+
MemoryObjectSendStream[ParsedMessage],
1919
]
2020

2121

@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
3232
"""
3333
# Create streams for both directions
3434
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
35-
JSONRPCMessage | Exception
35+
ParsedMessage | Exception
3636
](1)
3737
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
38-
JSONRPCMessage | Exception
38+
ParsedMessage | Exception
3939
](1)
4040

4141
client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,12 +60,9 @@ async def create_connected_server_and_client_session(
6060
) -> AsyncGenerator[ClientSession, None]:
6161
"""Creates a ClientSession that is connected to a running MCP server."""
6262
async with create_client_server_memory_streams() as (
63-
client_streams,
64-
server_streams,
63+
(client_read, client_write),
64+
(server_read, server_write),
6565
):
66-
client_read, client_write = client_streams
67-
server_read, server_write = server_streams
68-
6966
# Create a cancel scope for the server task
7067
async with anyio.create_task_group() as tg:
7168
tg.start_soon(

src/mcp/shared/session.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import anyio.lowlevel
88
import httpx
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, RootModel
1111
from typing_extensions import Self
1212

1313
from mcp.shared.exceptions import McpError
@@ -28,6 +28,22 @@
2828
ServerResult,
2929
)
3030

31+
RawT = TypeVar("RawT")
32+
33+
34+
class ParsedMessage(RootModel[JSONRPCMessage], Generic[RawT]):
35+
root: JSONRPCMessage
36+
raw: RawT | None = None
37+
38+
class Config:
39+
arbitrary_types_allowed = True
40+
41+
42+
ReadStream = MemoryObjectReceiveStream[ParsedMessage[RawT] | Exception]
43+
ReadStreamWriter = MemoryObjectSendStream[ParsedMessage[RawT] | Exception]
44+
WriteStream = MemoryObjectSendStream[ParsedMessage[RawT]]
45+
WriteStreamReader = MemoryObjectReceiveStream[ParsedMessage[RawT]]
46+
3147
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
3248
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
3349
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@@ -165,8 +181,8 @@ class BaseSession(
165181

166182
def __init__(
167183
self,
168-
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
169-
write_stream: MemoryObjectSendStream[JSONRPCMessage],
184+
read_stream: ReadStream,
185+
write_stream: WriteStream,
170186
receive_request_type: type[ReceiveRequestT],
171187
receive_notification_type: type[ReceiveNotificationT],
172188
# If none, reading will never time out
@@ -242,7 +258,9 @@ async def send_request(
242258

243259
# TODO: Support progress callbacks
244260

245-
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
261+
await self._write_stream.send(
262+
ParsedMessage(JSONRPCMessage(jsonrpc_request), None)
263+
)
246264

247265
try:
248266
with anyio.fail_after(
@@ -278,14 +296,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
278296
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
279297
)
280298

281-
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
299+
await self._write_stream.send(
300+
ParsedMessage(JSONRPCMessage(jsonrpc_notification))
301+
)
282302

283303
async def _send_response(
284304
self, request_id: RequestId, response: SendResultT | ErrorData
285305
) -> None:
286306
if isinstance(response, ErrorData):
287307
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
288-
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
308+
await self._write_stream.send(ParsedMessage(JSONRPCMessage(jsonrpc_error)))
289309
else:
290310
jsonrpc_response = JSONRPCResponse(
291311
jsonrpc="2.0",
@@ -294,18 +314,23 @@ async def _send_response(
294314
by_alias=True, mode="json", exclude_none=True
295315
),
296316
)
297-
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
317+
await self._write_stream.send(
318+
ParsedMessage(JSONRPCMessage(jsonrpc_response))
319+
)
298320

299321
async def _receive_loop(self) -> None:
300322
async with (
301323
self._read_stream,
302324
self._write_stream,
303325
self._incoming_message_stream_writer,
304326
):
305-
async for message in self._read_stream:
306-
if isinstance(message, Exception):
307-
await self._incoming_message_stream_writer.send(message)
308-
elif isinstance(message.root, JSONRPCRequest):
327+
async for raw_message in self._read_stream:
328+
if isinstance(raw_message, Exception):
329+
await self._incoming_message_stream_writer.send(raw_message)
330+
continue
331+
332+
message = raw_message.root
333+
if isinstance(message.root, JSONRPCRequest):
309334
validated_request = self._receive_request_type.model_validate(
310335
message.root.model_dump(
311336
by_alias=True, mode="json", exclude_none=True

0 commit comments

Comments
 (0)