diff --git a/README.md b/README.md index c062f3a9..12d0dd39 100644 --- a/README.md +++ b/README.md @@ -521,16 +521,9 @@ if __name__ == "__main__": The SDK provides a high-level client interface for connecting to MCP servers: ```python -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, types from mcp.client.stdio import stdio_client -# Create server parameters for stdio connection -server_params = StdioServerParameters( - command="python", # Executable - args=["example_server.py"], # Optional command line arguments - env=None, # Optional environment variables -) - # Optional: create a sampling callback async def handle_sampling_message( @@ -548,7 +541,12 @@ async def handle_sampling_message( async def run(): - async with stdio_client(server_params) as (read, write): + # Connect to server using stdio transport + async with stdio_client( + command="python", # Executable + args=["example_server.py"], # Optional command line arguments + env=None, # Optional environment variables + ) as (read, write): async with ClientSession( read, write, sampling_callback=handle_sampling_message ) as session: diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372c..474f87c3 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,5 +1,5 @@ from .client.session import ClientSession -from .client.stdio import StdioServerParameters, stdio_client +from .client.stdio import stdio_client from .server.session import ServerSession from .server.stdio import stdio_server from .shared.exceptions import McpError @@ -101,7 +101,6 @@ "ServerResult", "ServerSession", "SetLevelRequest", - "StdioServerParameters", "StopReason", "SubscribeRequest", "Tool", diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 2eaa3475..089b80a9 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -7,7 +7,6 @@ import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream -from pydantic import BaseModel, Field import mcp.types as types @@ -52,41 +51,19 @@ def get_default_environment() -> dict[str, str]: return env -class StdioServerParameters(BaseModel): - command: str - """The executable to run to start the server.""" - - args: list[str] = Field(default_factory=list) - """Command line arguments to pass to the executable.""" - - env: dict[str, str] | None = None - """ - The environment to use when spawning the process. - - If not specified, the result of get_default_environment() will be used. - """ - - encoding: str = "utf-8" - """ - The text encoding used when sending/receiving messages to the server - - defaults to utf-8 - """ - - encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" - """ - The text encoding error handler. - - See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values - """ - - @asynccontextmanager -async def stdio_client(server: StdioServerParameters): +async def stdio_client( + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + encoding: str = "utf-8", + encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict", +): """ Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. + + """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -98,8 +75,8 @@ async def stdio_client(server: StdioServerParameters): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) process = await anyio.open_process( - [server.command, *server.args], - env=server.env if server.env is not None else get_default_environment(), + [command] + (args or []), + env=env if env is not None else get_default_environment(), stderr=sys.stderr, ) @@ -111,8 +88,8 @@ async def stdout_reader(): buffer = "" async for chunk in TextReceiveStream( process.stdout, - encoding=server.encoding, - errors=server.encoding_error_handler, + encoding=encoding, + errors=encoding_error_handler, ): lines = (buffer + chunk).split("\n") buffer = lines.pop() @@ -137,8 +114,8 @@ async def stdin_writer(): json = message.model_dump_json(by_alias=True, exclude_none=True) await process.stdin.send( (json + "\n").encode( - encoding=server.encoding, - errors=server.encoding_error_handler, + encoding=encoding, + errors=encoding_error_handler, ) ) except anyio.ClosedResourceError: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd..ff559362 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -2,7 +2,7 @@ import pytest -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.stdio import stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -11,9 +11,7 @@ @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_client(): - server_parameters = StdioServerParameters(command=tee) - - async with stdio_client(server_parameters) as (read_stream, write_stream): + async with stdio_client(command=tee) as (read_stream, write_stream): # Test sending and receiving messages messages = [ JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),