Skip to content

Commit 1691b90

Browse files
Merge pull request #179 from smithery-ai/patch-1
Create Client websocket.py
2 parents e756315 + fb7d0c8 commit 1691b90

File tree

4 files changed

+377
-1
lines changed

4 files changed

+377
-1
lines changed

Diff for: pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
[project.optional-dependencies]
3636
rich = ["rich>=13.9.4"]
3737
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
38+
ws = ["websockets>=15.0.1"]
3839

3940
[project.scripts]
4041
mcp = "mcp.cli:app [cli]"

Diff for: src/mcp/client/websocket.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import json
2+
import logging
3+
from contextlib import asynccontextmanager
4+
from typing import AsyncGenerator
5+
6+
import anyio
7+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
8+
from pydantic import ValidationError
9+
from websockets.asyncio.client import connect as ws_connect
10+
from websockets.typing import Subprotocol
11+
12+
import mcp.types as types
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@asynccontextmanager
18+
async def websocket_client(url: str) -> AsyncGenerator[
19+
tuple[
20+
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
21+
MemoryObjectSendStream[types.JSONRPCMessage],
22+
],
23+
None,
24+
]:
25+
"""
26+
WebSocket client transport for MCP, symmetrical to the server version.
27+
28+
Connects to 'url' using the 'mcp' subprotocol, then yields:
29+
(read_stream, write_stream)
30+
31+
- read_stream: As you read from this stream, you'll receive either valid
32+
JSONRPCMessage objects or Exception objects (when validation fails).
33+
- write_stream: Write JSONRPCMessage objects to this stream to send them
34+
over the WebSocket to the server.
35+
"""
36+
37+
# Create two in-memory streams:
38+
# - One for incoming messages (read_stream, written by ws_reader)
39+
# - One for outgoing messages (write_stream, read by ws_writer)
40+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
41+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
42+
43+
# Connect using websockets, requesting the "mcp" subprotocol
44+
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
45+
46+
async def ws_reader():
47+
"""
48+
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
49+
and sends them into read_stream_writer.
50+
"""
51+
async with read_stream_writer:
52+
async for raw_text in ws:
53+
try:
54+
message = types.JSONRPCMessage.model_validate_json(raw_text)
55+
await read_stream_writer.send(message)
56+
except ValidationError as exc:
57+
# If JSON parse or model validation fails, send the exception
58+
await read_stream_writer.send(exc)
59+
60+
async def ws_writer():
61+
"""
62+
Reads JSON-RPC messages from write_stream_reader and
63+
sends them to the server.
64+
"""
65+
async with write_stream_reader:
66+
async for message in write_stream_reader:
67+
# Convert to a dict, then to JSON
68+
msg_dict = message.model_dump(
69+
by_alias=True, mode="json", exclude_none=True
70+
)
71+
await ws.send(json.dumps(msg_dict))
72+
73+
async with anyio.create_task_group() as tg:
74+
# Start reader and writer tasks
75+
tg.start_soon(ws_reader)
76+
tg.start_soon(ws_writer)
77+
78+
# Yield the receive/send streams
79+
yield (read_stream, write_stream)
80+
81+
# Once the caller's 'async with' block exits, we shut down
82+
tg.cancel_scope.cancel()

Diff for: tests/shared/test_ws.py

+230
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import multiprocessing
2+
import socket
3+
import time
4+
from typing import AsyncGenerator, Generator
5+
6+
import anyio
7+
import pytest
8+
import uvicorn
9+
from pydantic import AnyUrl
10+
from starlette.applications import Starlette
11+
from starlette.routing import WebSocketRoute
12+
13+
from mcp.client.session import ClientSession
14+
from mcp.client.websocket import websocket_client
15+
from mcp.server import Server
16+
from mcp.server.websocket import websocket_server
17+
from mcp.shared.exceptions import McpError
18+
from mcp.types import (
19+
EmptyResult,
20+
ErrorData,
21+
InitializeResult,
22+
ReadResourceResult,
23+
TextContent,
24+
TextResourceContents,
25+
Tool,
26+
)
27+
28+
SERVER_NAME = "test_server_for_WS"
29+
30+
31+
@pytest.fixture
32+
def server_port() -> int:
33+
with socket.socket() as s:
34+
s.bind(("127.0.0.1", 0))
35+
return s.getsockname()[1]
36+
37+
38+
@pytest.fixture
39+
def server_url(server_port: int) -> str:
40+
return f"ws://127.0.0.1:{server_port}"
41+
42+
43+
# Test server implementation
44+
class ServerTest(Server):
45+
def __init__(self):
46+
super().__init__(SERVER_NAME)
47+
48+
@self.read_resource()
49+
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
50+
if uri.scheme == "foobar":
51+
return f"Read {uri.host}"
52+
elif uri.scheme == "slow":
53+
# Simulate a slow resource
54+
await anyio.sleep(2.0)
55+
return f"Slow response from {uri.host}"
56+
57+
raise McpError(
58+
error=ErrorData(
59+
code=404, message="OOPS! no resource with that URI was found"
60+
)
61+
)
62+
63+
@self.list_tools()
64+
async def handle_list_tools() -> list[Tool]:
65+
return [
66+
Tool(
67+
name="test_tool",
68+
description="A test tool",
69+
inputSchema={"type": "object", "properties": {}},
70+
)
71+
]
72+
73+
@self.call_tool()
74+
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
75+
return [TextContent(type="text", text=f"Called {name}")]
76+
77+
78+
# Test fixtures
79+
def make_server_app() -> Starlette:
80+
"""Create test Starlette app with WebSocket transport"""
81+
server = ServerTest()
82+
83+
async def handle_ws(websocket):
84+
async with websocket_server(
85+
websocket.scope, websocket.receive, websocket.send
86+
) as streams:
87+
await server.run(
88+
streams[0], streams[1], server.create_initialization_options()
89+
)
90+
91+
app = Starlette(
92+
routes=[
93+
WebSocketRoute("/ws", endpoint=handle_ws),
94+
]
95+
)
96+
97+
return app
98+
99+
100+
def run_server(server_port: int) -> None:
101+
app = make_server_app()
102+
server = uvicorn.Server(
103+
config=uvicorn.Config(
104+
app=app, host="127.0.0.1", port=server_port, log_level="error"
105+
)
106+
)
107+
print(f"starting server on {server_port}")
108+
server.run()
109+
110+
# Give server time to start
111+
while not server.started:
112+
print("waiting for server to start")
113+
time.sleep(0.5)
114+
115+
116+
@pytest.fixture()
117+
def server(server_port: int) -> Generator[None, None, None]:
118+
proc = multiprocessing.Process(
119+
target=run_server, kwargs={"server_port": server_port}, daemon=True
120+
)
121+
print("starting process")
122+
proc.start()
123+
124+
# Wait for server to be running
125+
max_attempts = 20
126+
attempt = 0
127+
print("waiting for server to start")
128+
while attempt < max_attempts:
129+
try:
130+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
131+
s.connect(("127.0.0.1", server_port))
132+
break
133+
except ConnectionRefusedError:
134+
time.sleep(0.1)
135+
attempt += 1
136+
else:
137+
raise RuntimeError(
138+
"Server failed to start after {} attempts".format(max_attempts)
139+
)
140+
141+
yield
142+
143+
print("killing server")
144+
# Signal the server to stop
145+
proc.kill()
146+
proc.join(timeout=2)
147+
if proc.is_alive():
148+
print("server process failed to terminate")
149+
150+
151+
@pytest.fixture()
152+
async def initialized_ws_client_session(
153+
server, server_url: str
154+
) -> AsyncGenerator[ClientSession, None]:
155+
"""Create and initialize a WebSocket client session"""
156+
async with websocket_client(server_url + "/ws") as streams:
157+
async with ClientSession(*streams) as session:
158+
# Test initialization
159+
result = await session.initialize()
160+
assert isinstance(result, InitializeResult)
161+
assert result.serverInfo.name == SERVER_NAME
162+
163+
# Test ping
164+
ping_result = await session.send_ping()
165+
assert isinstance(ping_result, EmptyResult)
166+
167+
yield session
168+
169+
170+
# Tests
171+
@pytest.mark.anyio
172+
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
173+
"""Test the WebSocket connection establishment"""
174+
async with websocket_client(server_url + "/ws") as streams:
175+
async with ClientSession(*streams) as session:
176+
# Test initialization
177+
result = await session.initialize()
178+
assert isinstance(result, InitializeResult)
179+
assert result.serverInfo.name == SERVER_NAME
180+
181+
# Test ping
182+
ping_result = await session.send_ping()
183+
assert isinstance(ping_result, EmptyResult)
184+
185+
186+
@pytest.mark.anyio
187+
async def test_ws_client_happy_request_and_response(
188+
initialized_ws_client_session: ClientSession,
189+
) -> None:
190+
"""Test a successful request and response via WebSocket"""
191+
result = await initialized_ws_client_session.read_resource(
192+
AnyUrl("foobar://example")
193+
)
194+
assert isinstance(result, ReadResourceResult)
195+
assert isinstance(result.contents, list)
196+
assert len(result.contents) > 0
197+
assert isinstance(result.contents[0], TextResourceContents)
198+
assert result.contents[0].text == "Read example"
199+
200+
201+
@pytest.mark.anyio
202+
async def test_ws_client_exception_handling(
203+
initialized_ws_client_session: ClientSession,
204+
) -> None:
205+
"""Test exception handling in WebSocket communication"""
206+
with pytest.raises(McpError) as exc_info:
207+
await initialized_ws_client_session.read_resource(AnyUrl("unknown://example"))
208+
assert exc_info.value.error.code == 404
209+
210+
211+
@pytest.mark.anyio
212+
async def test_ws_client_timeout(
213+
initialized_ws_client_session: ClientSession,
214+
) -> None:
215+
"""Test timeout handling in WebSocket communication"""
216+
# Set a very short timeout to trigger a timeout exception
217+
with pytest.raises(TimeoutError):
218+
with anyio.fail_after(0.1): # 100ms timeout
219+
await initialized_ws_client_session.read_resource(AnyUrl("slow://example"))
220+
221+
# Now test that we can still use the session after a timeout
222+
with anyio.fail_after(5): # Longer timeout to allow completion
223+
result = await initialized_ws_client_session.read_resource(
224+
AnyUrl("foobar://example")
225+
)
226+
assert isinstance(result, ReadResourceResult)
227+
assert isinstance(result.contents, list)
228+
assert len(result.contents) > 0
229+
assert isinstance(result.contents[0], TextResourceContents)
230+
assert result.contents[0].text == "Read example"

0 commit comments

Comments
 (0)