Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8c9285e

Browse files
committedFeb 26, 2025
refactor: move streams to ParsedMessage
1 parent 6dea29b commit 8c9285e

File tree

10 files changed

+79
-48
lines changed

10 files changed

+79
-48
lines changed
 

‎src/mcp/client/session.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from datetime import timedelta
22
from typing import Any, Protocol
33

4-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
54
from pydantic import AnyUrl, TypeAdapter
65

76
import mcp.types as types
87
from mcp.shared.context import RequestContext
9-
from mcp.shared.session import BaseSession, RequestResponder
8+
from mcp.shared.session import BaseSession, ReadStream, RequestResponder, WriteStream
109
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1110

1211

@@ -57,8 +56,8 @@ class ClientSession(
5756
):
5857
def __init__(
5958
self,
60-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
61-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
59+
read_stream: ReadStream,
60+
write_stream: WriteStream,
6261
read_timeout_seconds: timedelta | None = None,
6362
sampling_callback: SamplingFnT | None = None,
6463
list_roots_callback: ListRootsFnT | None = None,

‎src/mcp/client/sse.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
import anyio
77
import httpx
88
from anyio.abc import TaskStatus
9-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
109
from httpx_sse import aconnect_sse
1110

1211
import mcp.types as types
12+
from mcp.shared.session import (
13+
ParsedMessage,
14+
ReadStream,
15+
ReadStreamWriter,
16+
WriteStream,
17+
WriteStreamReader,
18+
)
1319

1420
logger = logging.getLogger(__name__)
1521

@@ -31,11 +37,11 @@ async def sse_client(
3137
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3238
event before disconnecting. All other HTTP operations are controlled by `timeout`.
3339
"""
34-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
35-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
40+
read_stream: ReadStream
41+
read_stream_writer: ReadStreamWriter
3642

37-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
38-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
43+
write_stream: WriteStream
44+
write_stream_reader: WriteStreamReader
3945

4046
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4147
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -84,8 +90,11 @@ async def sse_reader(
8490

8591
case "message":
8692
try:
87-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
88-
sse.data
93+
message = ParsedMessage(
94+
types.JSONRPCMessage.model_validate_json( # noqa: E501
95+
sse.data
96+
),
97+
raw=sse,
8998
)
9099
logger.debug(
91100
f"Received server message: {message}"

‎src/mcp/server/lowlevel/server.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ async def main():
7676

7777
import mcp.types as types
7878
from mcp.server.lowlevel.helper_types import ReadResourceContents
79-
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
79+
from mcp.server.models import InitializationOptions
8080
from mcp.server.session import ServerSession
8181
from mcp.server.stdio import stdio_server as stdio_server
8282
from mcp.shared.context import RequestContext
8383
from mcp.shared.exceptions import McpError
84-
from mcp.shared.session import RequestResponder
84+
from mcp.shared.session import ReadStream, RequestResponder, WriteStream
8585

8686
logger = logging.getLogger(__name__)
8787

‎src/mcp/server/models.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,9 @@
33
and tools.
44
"""
55

6-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
76
from pydantic import BaseModel
87

9-
from mcp.types import JSONRPCMessage, ServerCapabilities
10-
11-
ReadStream = MemoryObjectReceiveStream[JSONRPCMessage | Exception]
12-
ReadStreamWriter = MemoryObjectSendStream[JSONRPCMessage | Exception]
13-
WriteStream = MemoryObjectSendStream[JSONRPCMessage]
14-
WriteStreamReader = MemoryObjectReceiveStream[JSONRPCMessage]
8+
from mcp.types import ServerCapabilities
159

1610

1711
class InitializationOptions(BaseModel):

‎src/mcp/server/session.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4545
from pydantic import AnyUrl
4646

4747
import mcp.types as types
48-
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
48+
from mcp.server.models import InitializationOptions
4949
from mcp.shared.session import (
5050
BaseSession,
51+
ReadStream,
5152
RequestResponder,
53+
WriteStream,
5254
)
5355

5456

‎src/mcp/server/sse.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ async def handle_sse(request):
4545
from starlette.types import Receive, Scope, Send
4646

4747
import mcp.types as types
48-
from mcp.server.models import (
48+
from mcp.shared.session import (
49+
ParsedMessage,
4950
ReadStream,
5051
ReadStreamWriter,
5152
WriteStream,
@@ -175,4 +176,4 @@ async def handle_post_message(
175176
logger.debug(f"Sending message to writer: {message}")
176177
response = Response("Accepted", status_code=202)
177178
await response(scope, receive, send)
178-
await writer.send(message)
179+
await writer.send(ParsedMessage(message, raw=request))

‎src/mcp/server/stdio.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ async def run_server():
2626
import anyio.lowlevel
2727

2828
import mcp.types as types
29-
from mcp.server.models import (
29+
from mcp.shared.session import (
30+
ParsedMessage,
3031
ReadStream,
3132
ReadStreamWriter,
3233
WriteStream,
@@ -71,7 +72,7 @@ async def stdin_reader():
7172
await read_stream_writer.send(exc)
7273
continue
7374

74-
await read_stream_writer.send(message)
75+
await read_stream_writer.send(ParsedMessage(message, raw=line))
7576
except anyio.ClosedResourceError:
7677
await anyio.lowlevel.checkpoint()
7778

‎src/mcp/server/websocket.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from starlette.websockets import WebSocket
77

88
import mcp.types as types
9-
from mcp.server.models import (
9+
from mcp.shared.session import (
10+
ParsedMessage,
1011
ReadStream,
1112
ReadStreamWriter,
1213
WriteStream,
@@ -45,7 +46,9 @@ async def ws_reader():
4546
await read_stream_writer.send(exc)
4647
continue
4748

48-
await read_stream_writer.send(client_message)
49+
await read_stream_writer.send(
50+
ParsedMessage(client_message, raw=message)
51+
)
4952
except anyio.ClosedResourceError:
5053
await websocket.close()
5154

‎src/mcp/shared/memory.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
14-
from mcp.types import JSONRPCMessage
14+
from mcp.shared.session import ParsedMessage
1515

1616
MessageStream = tuple[
17-
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
18-
MemoryObjectSendStream[JSONRPCMessage],
17+
MemoryObjectReceiveStream[ParsedMessage | Exception],
18+
MemoryObjectSendStream[ParsedMessage],
1919
]
2020

2121

@@ -32,10 +32,10 @@ async def create_client_server_memory_streams() -> (
3232
"""
3333
# Create streams for both directions
3434
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
35-
JSONRPCMessage | Exception
35+
ParsedMessage | Exception
3636
](1)
3737
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
38-
JSONRPCMessage | Exception
38+
ParsedMessage | Exception
3939
](1)
4040

4141
client_streams = (server_to_client_receive, client_to_server_send)
@@ -60,12 +60,9 @@ async def create_connected_server_and_client_session(
6060
) -> AsyncGenerator[ClientSession, None]:
6161
"""Creates a ClientSession that is connected to a running MCP server."""
6262
async with create_client_server_memory_streams() as (
63-
client_streams,
64-
server_streams,
63+
(client_read, client_write),
64+
(server_read, server_write),
6565
):
66-
client_read, client_write = client_streams
67-
server_read, server_write = server_streams
68-
6966
# Create a cancel scope for the server task
7067
async with anyio.create_task_group() as tg:
7168
tg.start_soon(

‎src/mcp/shared/session.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import anyio.lowlevel
88
import httpx
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, RootModel
1111

1212
from mcp.shared.exceptions import McpError
1313
from mcp.types import (
@@ -27,6 +27,22 @@
2727
ServerResult,
2828
)
2929

30+
RawT = TypeVar("RawT")
31+
32+
33+
class ParsedMessage(RootModel[JSONRPCMessage], Generic[RawT]):
34+
root: JSONRPCMessage
35+
raw: RawT | None = None
36+
37+
class Config:
38+
arbitrary_types_allowed = True
39+
40+
41+
ReadStream = MemoryObjectReceiveStream[ParsedMessage[RawT] | Exception]
42+
ReadStreamWriter = MemoryObjectSendStream[ParsedMessage[RawT] | Exception]
43+
WriteStream = MemoryObjectSendStream[ParsedMessage[RawT]]
44+
WriteStreamReader = MemoryObjectReceiveStream[ParsedMessage[RawT]]
45+
3046
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
3147
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
3248
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
@@ -159,8 +175,8 @@ class BaseSession(
159175

160176
def __init__(
161177
self,
162-
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
163-
write_stream: MemoryObjectSendStream[JSONRPCMessage],
178+
read_stream: ReadStream,
179+
write_stream: WriteStream,
164180
receive_request_type: type[ReceiveRequestT],
165181
receive_notification_type: type[ReceiveNotificationT],
166182
# If none, reading will never time out
@@ -225,7 +241,9 @@ async def send_request(
225241

226242
# TODO: Support progress callbacks
227243

228-
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
244+
await self._write_stream.send(
245+
ParsedMessage(JSONRPCMessage(jsonrpc_request), None)
246+
)
229247

230248
try:
231249
with anyio.fail_after(
@@ -261,14 +279,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
261279
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
262280
)
263281

264-
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
282+
await self._write_stream.send(
283+
ParsedMessage(JSONRPCMessage(jsonrpc_notification))
284+
)
265285

266286
async def _send_response(
267287
self, request_id: RequestId, response: SendResultT | ErrorData
268288
) -> None:
269289
if isinstance(response, ErrorData):
270290
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
271-
await self._write_stream.send(JSONRPCMessage(jsonrpc_error))
291+
await self._write_stream.send(ParsedMessage(JSONRPCMessage(jsonrpc_error)))
272292
else:
273293
jsonrpc_response = JSONRPCResponse(
274294
jsonrpc="2.0",
@@ -277,18 +297,23 @@ async def _send_response(
277297
by_alias=True, mode="json", exclude_none=True
278298
),
279299
)
280-
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
300+
await self._write_stream.send(
301+
ParsedMessage(JSONRPCMessage(jsonrpc_response))
302+
)
281303

282304
async def _receive_loop(self) -> None:
283305
async with (
284306
self._read_stream,
285307
self._write_stream,
286308
self._incoming_message_stream_writer,
287309
):
288-
async for message in self._read_stream:
289-
if isinstance(message, Exception):
290-
await self._incoming_message_stream_writer.send(message)
291-
elif isinstance(message.root, JSONRPCRequest):
310+
async for raw_message in self._read_stream:
311+
if isinstance(raw_message, Exception):
312+
await self._incoming_message_stream_writer.send(raw_message)
313+
continue
314+
315+
message = raw_message.root
316+
if isinstance(message.root, JSONRPCRequest):
292317
validated_request = self._receive_request_type.model_validate(
293318
message.root.model_dump(
294319
by_alias=True, mode="json", exclude_none=True

0 commit comments

Comments
 (0)
Please sign in to comment.