Skip to content

Commit 4c71c61

Browse files
Merge pull request #151 from modelcontextprotocol/merrill/tests-for-sse
add minimal tests for the SSE transport
2 parents 62a0af6 + d01d49e commit 4c71c61

File tree

1 file changed

+254
-0
lines changed

1 file changed

+254
-0
lines changed

tests/shared/test_sse.py

+254
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
import multiprocessing
2+
import socket
3+
import time
4+
from typing import AsyncGenerator, Generator
5+
6+
import anyio
7+
import httpx
8+
import pytest
9+
import uvicorn
10+
from pydantic import AnyUrl
11+
from starlette.applications import Starlette
12+
from starlette.requests import Request
13+
from starlette.routing import Mount, Route
14+
15+
from mcp.client.session import ClientSession
16+
from mcp.client.sse import sse_client
17+
from mcp.server import Server
18+
from mcp.server.sse import SseServerTransport
19+
from mcp.shared.exceptions import McpError
20+
from mcp.types import (
21+
EmptyResult,
22+
ErrorData,
23+
InitializeResult,
24+
ReadResourceResult,
25+
TextContent,
26+
TextResourceContents,
27+
Tool,
28+
)
29+
30+
SERVER_NAME = "test_server_for_SSE"
31+
32+
33+
@pytest.fixture
34+
def server_port() -> int:
35+
with socket.socket() as s:
36+
s.bind(("127.0.0.1", 0))
37+
return s.getsockname()[1]
38+
39+
40+
@pytest.fixture
41+
def server_url(server_port: int) -> str:
42+
return f"http://127.0.0.1:{server_port}"
43+
44+
45+
# Test server implementation
46+
class TestServer(Server):
47+
def __init__(self):
48+
super().__init__(SERVER_NAME)
49+
50+
@self.read_resource()
51+
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
52+
if uri.scheme == "foobar":
53+
return f"Read {uri.host}"
54+
elif uri.scheme == "slow":
55+
# Simulate a slow resource
56+
await anyio.sleep(2.0)
57+
return f"Slow response from {uri.host}"
58+
59+
raise McpError(
60+
error=ErrorData(
61+
code=404, message="OOPS! no resource with that URI was found"
62+
)
63+
)
64+
65+
@self.list_tools()
66+
async def handle_list_tools() -> list[Tool]:
67+
return [
68+
Tool(
69+
name="test_tool",
70+
description="A test tool",
71+
inputSchema={"type": "object", "properties": {}},
72+
)
73+
]
74+
75+
@self.call_tool()
76+
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
77+
return [TextContent(type="text", text=f"Called {name}")]
78+
79+
80+
# Test fixtures
81+
def make_server_app() -> Starlette:
82+
"""Create test Starlette app with SSE transport"""
83+
sse = SseServerTransport("/messages/")
84+
server = TestServer()
85+
86+
async def handle_sse(request: Request) -> None:
87+
async with sse.connect_sse(
88+
request.scope, request.receive, request._send
89+
) as streams:
90+
await server.run(
91+
streams[0], streams[1], server.create_initialization_options()
92+
)
93+
94+
app = Starlette(
95+
routes=[
96+
Route("/sse", endpoint=handle_sse),
97+
Mount("/messages/", app=sse.handle_post_message),
98+
]
99+
)
100+
101+
return app
102+
103+
104+
def run_server(server_port: int) -> None:
105+
app = make_server_app()
106+
server = uvicorn.Server(
107+
config=uvicorn.Config(
108+
app=app, host="127.0.0.1", port=server_port, log_level="error"
109+
)
110+
)
111+
print(f"starting server on {server_port}")
112+
server.run()
113+
114+
# Give server time to start
115+
while not server.started:
116+
print("waiting for server to start")
117+
time.sleep(0.5)
118+
119+
120+
@pytest.fixture()
121+
def server(server_port: int) -> Generator[None, None, None]:
122+
proc = multiprocessing.Process(
123+
target=run_server, kwargs={"server_port": server_port}, daemon=True
124+
)
125+
print("starting process")
126+
proc.start()
127+
128+
# Wait for server to be running
129+
max_attempts = 20
130+
attempt = 0
131+
print("waiting for server to start")
132+
while attempt < max_attempts:
133+
try:
134+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
135+
s.connect(("127.0.0.1", server_port))
136+
break
137+
except ConnectionRefusedError:
138+
time.sleep(0.1)
139+
attempt += 1
140+
else:
141+
raise RuntimeError(
142+
"Server failed to start after {} attempts".format(max_attempts)
143+
)
144+
145+
yield
146+
147+
print("killing server")
148+
# Signal the server to stop
149+
proc.kill()
150+
proc.join(timeout=2)
151+
if proc.is_alive():
152+
print("server process failed to terminate")
153+
154+
155+
@pytest.fixture()
156+
async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]:
157+
"""Create test client"""
158+
async with httpx.AsyncClient(base_url=server_url) as client:
159+
yield client
160+
161+
162+
# Tests
163+
@pytest.mark.anyio
164+
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
165+
"""Test the SSE connection establishment simply with an HTTP client."""
166+
async with anyio.create_task_group():
167+
async def connection_test() -> None:
168+
async with http_client.stream("GET", "/sse") as response:
169+
assert response.status_code == 200
170+
assert (
171+
response.headers["content-type"]
172+
== "text/event-stream; charset=utf-8"
173+
)
174+
175+
line_number = 0
176+
async for line in response.aiter_lines():
177+
if line_number == 0:
178+
assert line == "event: endpoint"
179+
elif line_number == 1:
180+
assert line.startswith("data: /messages/?session_id=")
181+
else:
182+
return
183+
line_number += 1
184+
185+
# Add timeout to prevent test from hanging if it fails
186+
with anyio.fail_after(3):
187+
await connection_test()
188+
189+
190+
@pytest.mark.anyio
191+
async def test_sse_client_basic_connection(server: None, server_url: str) -> None:
192+
async with sse_client(server_url + "/sse") as streams:
193+
async with ClientSession(*streams) as session:
194+
# Test initialization
195+
result = await session.initialize()
196+
assert isinstance(result, InitializeResult)
197+
assert result.serverInfo.name == SERVER_NAME
198+
199+
# Test ping
200+
ping_result = await session.send_ping()
201+
assert isinstance(ping_result, EmptyResult)
202+
203+
204+
@pytest.fixture
205+
async def initialized_sse_client_session(
206+
server, server_url: str
207+
) -> AsyncGenerator[ClientSession, None]:
208+
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
209+
async with ClientSession(*streams) as session:
210+
await session.initialize()
211+
yield session
212+
213+
214+
215+
@pytest.mark.anyio
216+
async def test_sse_client_happy_request_and_response(
217+
initialized_sse_client_session: ClientSession,
218+
) -> None:
219+
session = initialized_sse_client_session
220+
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
221+
assert len(response.contents) == 1
222+
assert isinstance(response.contents[0], TextResourceContents)
223+
assert response.contents[0].text == "Read should-work"
224+
225+
226+
@pytest.mark.anyio
227+
async def test_sse_client_exception_handling(
228+
initialized_sse_client_session: ClientSession,
229+
) -> None:
230+
session = initialized_sse_client_session
231+
with pytest.raises(McpError, match="OOPS! no resource with that URI was found"):
232+
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
233+
234+
235+
@pytest.mark.anyio
236+
@pytest.mark.skip(
237+
"this test highlights a possible bug in SSE read timeout exception handling"
238+
)
239+
async def test_sse_client_timeout(
240+
initialized_sse_client_session: ClientSession,
241+
) -> None:
242+
session = initialized_sse_client_session
243+
244+
# sanity check that normal, fast responses are working
245+
response = await session.read_resource(uri=AnyUrl("foobar://1"))
246+
assert isinstance(response, ReadResourceResult)
247+
248+
with anyio.move_on_after(3):
249+
with pytest.raises(McpError, match="Read timed out"):
250+
response = await session.read_resource(uri=AnyUrl("slow://2"))
251+
# we should receive an error here
252+
return
253+
254+
pytest.fail("the client should have timed out and returned an error already")

0 commit comments

Comments
 (0)