Skip to content

Commit 6dea29b

Browse files
committed
refactor: improve typing with memory stream type aliases
Move memory stream type definitions to models.py and use them throughout the codebase for better type safety and maintainability. GitHub-Issue:#201
1 parent 775f879 commit 6dea29b

File tree

6 files changed

+44
-29
lines changed

6 files changed

+44
-29
lines changed

src/mcp/server/lowlevel/server.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,11 @@ async def main():
7272
from typing import Any, AsyncIterator, Generic, TypeVar
7373

7474
import anyio
75-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7675
from pydantic import AnyUrl
7776

7877
import mcp.types as types
7978
from mcp.server.lowlevel.helper_types import ReadResourceContents
80-
from mcp.server.models import InitializationOptions
79+
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
8180
from mcp.server.session import ServerSession
8281
from mcp.server.stdio import stdio_server as stdio_server
8382
from mcp.shared.context import RequestContext
@@ -472,8 +471,8 @@ async def handler(req: types.CompleteRequest):
472471

473472
async def run(
474473
self,
475-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
476-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
474+
read_stream: ReadStream,
475+
write_stream: WriteStream,
477476
initialization_options: InitializationOptions,
478477
# When False, exceptions are returned as messages to the client.
479478
# When True, exceptions are raised, which will cause the server to shut down

src/mcp/server/models.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
and tools.
44
"""
55

6+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
67
from pydantic import BaseModel
78

8-
from mcp.types import (
9-
ServerCapabilities,
10-
)
9+
from mcp.types import JSONRPCMessage, ServerCapabilities
10+
11+
ReadStream = MemoryObjectReceiveStream[JSONRPCMessage | Exception]
12+
ReadStreamWriter = MemoryObjectSendStream[JSONRPCMessage | Exception]
13+
WriteStream = MemoryObjectSendStream[JSONRPCMessage]
14+
WriteStreamReader = MemoryObjectReceiveStream[JSONRPCMessage]
1115

1216

1317
class InitializationOptions(BaseModel):

src/mcp/server/session.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4242

4343
import anyio
4444
import anyio.lowlevel
45-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4645
from pydantic import AnyUrl
4746

4847
import mcp.types as types
49-
from mcp.server.models import InitializationOptions
48+
from mcp.server.models import InitializationOptions, ReadStream, WriteStream
5049
from mcp.shared.session import (
5150
BaseSession,
5251
RequestResponder,
@@ -73,8 +72,8 @@ class ServerSession(
7372

7473
def __init__(
7574
self,
76-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
77-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
75+
read_stream: ReadStream,
76+
write_stream: WriteStream,
7877
init_options: InitializationOptions,
7978
) -> None:
8079
super().__init__(

src/mcp/server/sse.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,19 @@ async def handle_sse(request):
3838
from uuid import UUID, uuid4
3939

4040
import anyio
41-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4241
from pydantic import ValidationError
4342
from sse_starlette import EventSourceResponse
4443
from starlette.requests import Request
4544
from starlette.responses import Response
4645
from starlette.types import Receive, Scope, Send
4746

4847
import mcp.types as types
48+
from mcp.server.models import (
49+
ReadStream,
50+
ReadStreamWriter,
51+
WriteStream,
52+
WriteStreamReader,
53+
)
4954

5055
logger = logging.getLogger(__name__)
5156

@@ -63,9 +68,7 @@ class SseServerTransport:
6368
"""
6469

6570
_endpoint: str
66-
_read_stream_writers: dict[
67-
UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
68-
]
71+
_read_stream_writers: dict[UUID, ReadStreamWriter]
6972

7073
def __init__(self, endpoint: str) -> None:
7174
"""
@@ -85,11 +88,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8588
raise ValueError("connect_sse can only handle HTTP requests")
8689

8790
logger.debug("Setting up SSE connection")
88-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
89-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
91+
read_stream: ReadStream
92+
read_stream_writer: ReadStreamWriter
9093

91-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
92-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
94+
write_stream: WriteStream
95+
write_stream_reader: WriteStreamReader
9396

9497
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9598
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

src/mcp/server/stdio.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@ async def run_server():
2424

2525
import anyio
2626
import anyio.lowlevel
27-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2827

2928
import mcp.types as types
29+
from mcp.server.models import (
30+
ReadStream,
31+
ReadStreamWriter,
32+
WriteStream,
33+
WriteStreamReader,
34+
)
3035

3136

3237
@asynccontextmanager
@@ -47,11 +52,11 @@ async def stdio_server(
4752
if not stdout:
4853
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
4954

50-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
51-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
55+
read_stream: ReadStream
56+
read_stream_writer: ReadStreamWriter
5257

53-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
54-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
58+
write_stream: WriteStream
59+
write_stream_reader: WriteStreamReader
5560

5661
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
5762
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

src/mcp/server/websocket.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
from contextlib import asynccontextmanager
33

44
import anyio
5-
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
65
from starlette.types import Receive, Scope, Send
76
from starlette.websockets import WebSocket
87

98
import mcp.types as types
9+
from mcp.server.models import (
10+
ReadStream,
11+
ReadStreamWriter,
12+
WriteStream,
13+
WriteStreamReader,
14+
)
1015

1116
logger = logging.getLogger(__name__)
1217

@@ -21,11 +26,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send):
2126
websocket = WebSocket(scope, receive, send)
2227
await websocket.accept(subprotocol="mcp")
2328

24-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
25-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
29+
read_stream: ReadStream
30+
read_stream_writer: ReadStreamWriter
2631

27-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
28-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
32+
write_stream: WriteStream
33+
write_stream_reader: WriteStreamReader
2934

3035
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
3136
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

0 commit comments

Comments
 (0)