diff --git a/README.md b/README.md index 2611e25f0..1ef136b6b 100644 --- a/README.md +++ b/README.md @@ -796,6 +796,29 @@ async def main(): tool_result = await session.call_tool("echo", {"message": "hello"}) ``` +## Configuring the HTTP client + +It is possible to override the httpx client to customize for your needs. + +* proxy +* specific authentication needs mTLS and token +* advanced configs supported by httpx + +```python +import httpx +from mcp.client.streamable_http import streamablehttp_client + +http_client = httpx.AsyncClient( + base_url="http://someserver", + proxy="http://proxy.local", + headers={ + "Content": "application/json", + "Accept": "application/json,text/event-stream", + }, +) +streamable_client = streamablehttp_client("/v1/mcp", http_client=http_client) +``` + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a..04b2d8ce2 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -427,6 +427,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + http_client: httpx.AsyncClient | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -448,6 +449,12 @@ async def streamablehttp_client( - get_session_id_callback: Function to retrieve the current session ID """ transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + client = http_client or create_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout.seconds, read=transport.sse_read_timeout.seconds + ), + ) read_stream_writer, read_stream = anyio.create_memory_object_stream[ SessionMessage | Exception @@ -460,12 +467,7 @@ async def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with create_mcp_http_client( - headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout.seconds, read=transport.sse_read_timeout.seconds - ), - ) as client: + async with client: # Define callbacks that need access to tg def start_get_stream() -> None: tg.start_soon( diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 79285ecb1..12d7e5d37 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -10,6 +10,7 @@ import time from collections.abc import Generator +import httpx import pytest import uvicorn from pydantic import AnyUrl @@ -460,6 +461,39 @@ async def test_fastmcp_streamable_http( assert tool_result.content[0].text == "Echo: hello" +@pytest.mark.anyio +async def test_fastmcp_streamable_http_with_custom_client( + streamable_http_server: None, http_server_url: str +) -> None: + """Test that FastMCP works with StreamableHTTP transport using custom HttpClient.""" + http_client = httpx.AsyncClient( + base_url=http_server_url, + follow_redirects=True, + headers={ + "Content-Type": "application/json", + "Accept": "application/json,text/event-stream", + }, + ) + # Connect to the server using StreamableHTTP + async with streamablehttp_client("/mcp", http_client=http_client) as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Test that we can call tools without authentication + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + @pytest.mark.anyio async def test_fastmcp_stateless_streamable_http( stateless_http_server: None, stateless_http_server_url: str