Skip to content

Commit 97201cc

Browse files
authored
Strict types on the client side (#285)
1 parent 7196604 commit 97201cc

File tree

6 files changed

+23
-14
lines changed

6 files changed

+23
-14
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ venvPath = "."
7777
venv = ".venv"
7878
strict = [
7979
"src/mcp/server/fastmcp/tools/base.py",
80+
"src/mcp/client/*.py"
8081
]
8182

8283
[tool.ruff.lint]

src/mcp/client/__main__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from urllib.parse import urlparse
66

77
import anyio
8+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
89

910
from mcp.client.session import ClientSession
1011
from mcp.client.sse import sse_client
1112
from mcp.client.stdio import StdioServerParameters, stdio_client
13+
from mcp.types import JSONRPCMessage
1214

1315
if not sys.warnoptions:
1416
import warnings
@@ -29,7 +31,10 @@ async def receive_loop(session: ClientSession):
2931
logger.info("Received message from server: %s", message)
3032

3133

32-
async def run_session(read_stream, write_stream):
34+
async def run_session(
35+
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
36+
write_stream: MemoryObjectSendStream[JSONRPCMessage],
37+
):
3338
async with (
3439
ClientSession(read_stream, write_stream) as session,
3540
anyio.create_task_group() as tg,

src/mcp/client/session.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,12 @@ def __init__(
7676
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
7777

7878
async def initialize(self) -> types.InitializeResult:
79-
sampling = (
80-
types.SamplingCapability() if self._sampling_callback is not None else None
81-
)
82-
roots = (
83-
types.RootsCapability(
84-
# TODO: Should this be based on whether we
85-
# _will_ send notifications, or only whether
86-
# they're supported?
87-
listChanged=True,
88-
)
89-
if self._list_roots_callback is not None
90-
else None
79+
sampling = types.SamplingCapability()
80+
roots = types.RootsCapability(
81+
# TODO: Should this be based on whether we
82+
# _will_ send notifications, or only whether
83+
# they're supported?
84+
listChanged=True,
9185
)
9286

9387
result = await self.send_request(

src/mcp/client/sse.py

+4
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ async def sse_reader(
9898
continue
9999

100100
await read_stream_writer.send(message)
101+
case _:
102+
logger.warning(
103+
f"Unknown SSE event: {sse.event}"
104+
)
101105
except Exception as exc:
102106
logger.error(f"Error in sse_reader: {exc}")
103107
await read_stream_writer.send(exc)

src/mcp/client/websocket.py

+5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ async def websocket_client(
3939
# Create two in-memory streams:
4040
# - One for incoming messages (read_stream, written by ws_reader)
4141
# - One for outgoing messages (write_stream, read by ws_writer)
42+
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
43+
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
44+
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
45+
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
46+
4247
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4348
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4449

src/mcp/shared/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from mcp.types import LATEST_PROTOCOL_VERSION
22

3-
SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION]
3+
SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)

0 commit comments

Comments
 (0)