Skip to content

Commit ab02c0d

Browse files
committed
fix process hanging on bad stdio connection params
1 parent 38229ba commit ab02c0d

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

Diff for: src/mcp/client/stdio/__init__.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import anyio
88
import anyio.lowlevel
9+
from anyio.abc import Process
910
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1011
from anyio.streams.text import TextReceiveStream
1112
from pydantic import BaseModel, Field
@@ -38,6 +39,10 @@
3839
)
3940

4041

42+
class ProcessTerminatedEarlyError(Exception):
43+
"""Raised when a process terminates unexpectedly."""
44+
45+
4146
def get_default_environment() -> dict[str, str]:
4247
"""
4348
Returns a default environment object including only environment variables deemed
@@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
110115
command = _get_executable_command(server.command)
111116

112117
# Open process with stderr piped for capture
113-
process = await _create_platform_compatible_process(
118+
process: Process = await _create_platform_compatible_process(
114119
command=command,
115120
args=server.args,
116121
env=(
@@ -163,20 +168,36 @@ async def stdin_writer():
163168
except anyio.ClosedResourceError:
164169
await anyio.lowlevel.checkpoint()
165170

171+
process_error: str | None = None
172+
166173
async with (
167174
anyio.create_task_group() as tg,
168175
process,
169176
):
170177
tg.start_soon(stdout_reader)
171178
tg.start_soon(stdin_writer)
179+
# tg.start_soon(monitor_process, tg.cancel_scope)
172180
try:
173181
yield read_stream, write_stream
174182
finally:
175-
# Clean up process to prevent any dangling orphaned processes
176-
if sys.platform == "win32":
177-
await terminate_windows_process(process)
183+
await read_stream.aclose()
184+
await write_stream.aclose()
185+
await read_stream_writer.aclose()
186+
await write_stream_reader.aclose()
187+
188+
if process.returncode is not None and process.returncode != 0:
189+
process_error = f"Process exited with code {process.returncode}."
178190
else:
179-
process.terminate()
191+
# Clean up process to prevent any dangling orphaned processes
192+
if sys.platform == "win32":
193+
await terminate_windows_process(process)
194+
else:
195+
process.terminate()
196+
197+
if process_error:
198+
# Raise outside the task group so that the error is not wrapped in an
199+
# ExceptionGroup
200+
raise ProcessTerminatedEarlyError(process_error)
180201

181202

182203
def _get_executable_command(command: str) -> str:

Diff for: tests/client/test_stdio.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import pytest
44
from anyio import fail_after
55

6-
from mcp.client.stdio import StdioServerParameters, stdio_client
6+
from mcp.client.stdio import (
7+
ProcessTerminatedEarlyError,
8+
StdioServerParameters,
9+
stdio_client,
10+
)
711
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
812

913
tee: str = shutil.which("tee") # type: ignore
@@ -51,9 +55,20 @@ async def test_stdio_client_bad_path():
5155
command="uv", args=["run", "non-existent-file.py"]
5256
)
5357

54-
try:
55-
with fail_after(1):
56-
async with stdio_client(server_parameters) as (read_stream, write_stream):
57-
pass
58-
except TimeoutError:
59-
pytest.fail("The connection hung.")
58+
with pytest.raises(ProcessTerminatedEarlyError):
59+
try:
60+
with fail_after(1):
61+
async with stdio_client(server_parameters) as (
62+
read_stream,
63+
_,
64+
):
65+
# Try waiting for read_stream so that we don't exit before the
66+
# process fails.
67+
async with read_stream:
68+
async for message in read_stream:
69+
if isinstance(message, Exception):
70+
raise message
71+
72+
pass
73+
except TimeoutError:
74+
pytest.fail("The connection hung.")

0 commit comments

Comments
 (0)