Skip to content

Commit c4fb621

Browse files
committed
Clean up code
1 parent d3f0dea commit c4fb621

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

src/mcp/client/streamable_http.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
MCP_SESSION_ID = "mcp-session-id"
4343
LAST_EVENT_ID = "last-event-id"
4444
CONTENT_TYPE = "content-type"
45+
HEADER_CAPTURE = "[TESTING_HEADER_CAPTURE]"
4546
ACCEPT = "Accept"
4647

4748

@@ -275,7 +276,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
275276
async def _is_testing_header_capture(self, response: httpx.Response) -> str | None:
276277
try:
277278
content = await response.aread()
278-
if content.decode().startswith("[TESTING_HEADER_CAPTURE]"):
279+
if content.decode().startswith(HEADER_CAPTURE):
279280
return content.decode()
280281
except Exception as _:
281282
return None
@@ -306,6 +307,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
306307
)
307308
return
308309

310+
# To test if headers are being forwarded correctly, in unit tests
311+
# we have a mock server that returns a 418 status code with the
312+
# HEADER_CAPTURE prefix. If the response has this status code
313+
# with the prefix, return the response content as part of the error message.
309314
if response.status_code == 418:
310315
test_error_message = await self._is_testing_header_capture(response)
311316
# If this is coming from the test case return the response content

tests/shared/test_streamable_http.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import mcp.types as types
2727
from mcp.client.session import ClientSession
28-
from mcp.client.streamable_http import streamablehttp_client
28+
from mcp.client.streamable_http import HEADER_CAPTURE, streamablehttp_client
2929
from mcp.server import Server
3030
from mcp.server.streamable_http import (
3131
MCP_SESSION_ID_HEADER,
@@ -262,7 +262,7 @@ async def header_capture_wrapper(scope, receive, send):
262262

263263
# Return error response with headers in body
264264
response = Response(
265-
"[TESTING_HEADER_CAPTURE]:" + json.dumps({"headers": headers}),
265+
HEADER_CAPTURE + json.dumps({"headers": headers}),
266266
status_code=418,
267267
)
268268
await response(scope, receive, send)
@@ -279,7 +279,7 @@ async def header_capture_wrapper(scope, receive, send):
279279

280280

281281
def _get_captured_headrs(str) -> dict[str, str]:
282-
return json.loads(str.split("[TESTING_HEADER_CAPTURE]:")[1])["headers"]
282+
return json.loads(str.split(HEADER_CAPTURE)[1])["headers"]
283283

284284

285285
def run_server(
@@ -356,21 +356,23 @@ def _start_basic_server(
356356

357357
# Wait for server to be running
358358
max_attempts = 20
359-
for attempt in range(max_attempts):
359+
attempt = 0
360+
while attempt < max_attempts:
360361
try:
361362
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
362363
s.connect(("127.0.0.1", basic_server_port))
363364
break
364365
except ConnectionRefusedError:
365366
time.sleep(0.1)
367+
attempt += 1
366368
else:
367369
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
368370

369-
try:
370-
yield
371-
finally:
372-
proc.kill()
373-
proc.join(timeout=2)
371+
yield
372+
373+
# Clean up
374+
proc.kill()
375+
proc.join(timeout=2)
374376

375377

376378
@pytest.fixture

0 commit comments

Comments
 (0)