From 25600d4fd902a991c00a75ef8fd2dbd83aebf892 Mon Sep 17 00:00:00 2001 From: Joshua Newman Date: Wed, 2 Apr 2025 12:08:13 -0700 Subject: [PATCH 1/5] streamable --- src/mcp/client/session.py | 6 +- src/mcp/client/streamable.py | 178 +++++++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 src/mcp/client/streamable.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 65d5e11e..46a08e00 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -97,6 +97,7 @@ def __init__( list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, + supported_protocol_versions: tuple[str | int, ...] | None = None, ) -> None: super().__init__( read_stream, @@ -109,6 +110,9 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler + self._supported_protocol_versions = ( + supported_protocol_versions or SUPPORTED_PROTOCOL_VERSIONS + ) async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -137,7 +141,7 @@ async def initialize(self) -> types.InitializeResult: types.InitializeResult, ) - if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: + if result.protocolVersion not in self._supported_protocol_versions: raise RuntimeError( "Unsupported protocol version from the server: " f"{result.protocolVersion}" diff --git a/src/mcp/client/streamable.py b/src/mcp/client/streamable.py new file mode 100644 index 00000000..46c50796 --- /dev/null +++ b/src/mcp/client/streamable.py @@ -0,0 +1,178 @@ +import logging +from contextlib import asynccontextmanager + +import anyio +import httpx +from httpx_sse import EventSource +from pydantic import TypeAdapter + +import mcp.types as types +from mcp.client.sse import sse_client + +logger = logging.getLogger(__name__) + +STREAMABLE_PROTOCOL_VERSION = "2025-03-26" +SUPPORTED_PROTOCOL_VERSIONS: tuple[str, ...] = ( + types.LATEST_PROTOCOL_VERSION, + STREAMABLE_PROTOCOL_VERSION, +) + + +@asynccontextmanager +async def streamable_client( + url: str, + timeout: float = 5, +): + """ + Client transport for streamable HTTP, with fallback to SSE. + """ + if await _is_old_sse_server(url, timeout): + async with sse_client(url) as (read_stream, write_stream): + yield read_stream, write_stream + return + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + types.JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + types.JSONRPCMessage + ](0) + + async def handle_response(text: str) -> None: + items = _maybe_list_adapter.validate_json(text) + if isinstance(items, types.JSONRPCMessage): + items = [items] + for item in items: + await read_stream_writer.send(item) + + headers: tuple[tuple[str, str], ...] = () + + async with anyio.create_task_group() as tg: + try: + async with httpx.AsyncClient(timeout=timeout) as client: + + async def sse_reader(event_source: EventSource): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "message": + try: + await handle_response(sse.data) + logger.debug( + f"Received server message: {sse.data}" + ) + except Exception as exc: + logger.error( + f"Error parsing server message: {exc}" + ) + await read_stream_writer.send(exc) + continue + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + except Exception as exc: + logger.error(f"Error in sse_reader: {exc}") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(): + nonlocal headers + try: + async with write_stream_reader: + async for message in write_stream_reader: + logger.debug(f"Sending client message: {message}") + response = await client.post( + url, + json=message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + headers=( + ("accept", "application/json"), + ("accept", "text/event-stream"), + *headers, + ), + ) + logger.debug( + f"response {url=} content-type={response.headers.get("content-type")} body={response.text}" + ) + + response.raise_for_status() + match response.headers.get("mcp-session-id"): + case str() as session_id: + headers = (("mcp-session-id", session_id),) + case _: + pass + + match response.headers.get("content-type"): + case "text/event-stream": + await sse_reader(EventSource(response)) + case "application/json": + await handle_response(response.text) + case None: + pass + case unknown: + logger.warning( + f"Unknown content type: {unknown}" + ) + + logger.debug( + "Client message sent successfully: " + f"{response.status_code}" + ) + except Exception as exc: + logger.error(f"Error in post_writer: {exc}", exc_info=True) + finally: + await write_stream.aclose() + + tg.start_soon(post_writer) + + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + +_maybe_list_adapter: TypeAdapter[types.JSONRPCMessage | list[types.JSONRPCMessage]] = ( + TypeAdapter(types.JSONRPCMessage | list[types.JSONRPCMessage]) +) + + +async def _is_old_sse_server(url: str, timeout: float) -> bool: + """ + Test whether this is an old SSE MCP server. + + See: https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/transports/#backwards-compatibility + """ + async with httpx.AsyncClient(timeout=timeout) as client: + test_initialize_request = types.InitializeRequest( + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=STREAMABLE_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation(name="mcp", version="0.1.0"), + ), + ) + response = await client.post( + url, + json=types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method=test_initialize_request.method, + params=test_initialize_request.params.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + headers=( + ("accept", "application/json"), + ("accept", "text/event-stream"), + ), + ) + if 400 <= response.status_code < 500: + return True + return False From 178571b944f6ab6fed9b4d3762bd6b9fbe70087f Mon Sep 17 00:00:00 2001 From: Joshua Newman Date: Wed, 2 Apr 2025 16:51:32 -0700 Subject: [PATCH 2/5] fmt --- src/mcp/client/streamable.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/streamable.py b/src/mcp/client/streamable.py index 46c50796..1bac3e91 100644 --- a/src/mcp/client/streamable.py +++ b/src/mcp/client/streamable.py @@ -95,8 +95,11 @@ async def post_writer(): *headers, ), ) + content_type = response.headers.get("content-type") logger.debug( - f"response {url=} content-type={response.headers.get("content-type")} body={response.text}" + f"response {url=} " + f"content-type={content_type} " + f"body={response.text}" ) response.raise_for_status() From e4b7f8d50d22b0b05d2346c3cb63bb82c09ad46d Mon Sep 17 00:00:00 2001 From: Joshua Newman Date: Thu, 3 Apr 2025 16:30:48 -0700 Subject: [PATCH 3/5] add headers --- src/mcp/client/streamable.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/streamable.py b/src/mcp/client/streamable.py index 1bac3e91..2d6af97d 100644 --- a/src/mcp/client/streamable.py +++ b/src/mcp/client/streamable.py @@ -1,5 +1,6 @@ import logging from contextlib import asynccontextmanager +from typing import Any import anyio import httpx @@ -21,6 +22,7 @@ @asynccontextmanager async def streamable_client( url: str, + headers: dict[str, Any] | None = None, timeout: float = 5, ): """ @@ -45,7 +47,7 @@ async def handle_response(text: str) -> None: for item in items: await read_stream_writer.send(item) - headers: tuple[tuple[str, str], ...] = () + session_headers = headers.copy() if headers else {} async with anyio.create_task_group() as tg: try: @@ -92,7 +94,7 @@ async def post_writer(): headers=( ("accept", "application/json"), ("accept", "text/event-stream"), - *headers, + *session_headers.items(), ), ) content_type = response.headers.get("content-type") @@ -105,7 +107,7 @@ async def post_writer(): response.raise_for_status() match response.headers.get("mcp-session-id"): case str() as session_id: - headers = (("mcp-session-id", session_id),) + session_headers["mcp-session-id"] = session_id case _: pass From 8b2cbacd4f63678d64b2a1453dcd81d52ba1d752 Mon Sep 17 00:00:00 2001 From: Joshua Newman Date: Thu, 3 Apr 2025 16:39:28 -0700 Subject: [PATCH 4/5] forward headers to old transport --- src/mcp/client/streamable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/streamable.py b/src/mcp/client/streamable.py index 2d6af97d..39e24650 100644 --- a/src/mcp/client/streamable.py +++ b/src/mcp/client/streamable.py @@ -29,7 +29,7 @@ async def streamable_client( Client transport for streamable HTTP, with fallback to SSE. """ if await _is_old_sse_server(url, timeout): - async with sse_client(url) as (read_stream, write_stream): + async with sse_client(url, headers=headers) as (read_stream, write_stream): yield read_stream, write_stream return From cb222fb5ee8fee0a01d5a817ca0defbb86adb683 Mon Sep 17 00:00:00 2001 From: Joshua Newman Date: Fri, 4 Apr 2025 16:53:41 -0700 Subject: [PATCH 5/5] pass headers to version test too --- src/mcp/client/streamable.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/streamable.py b/src/mcp/client/streamable.py index 39e24650..e3b07785 100644 --- a/src/mcp/client/streamable.py +++ b/src/mcp/client/streamable.py @@ -28,7 +28,7 @@ async def streamable_client( """ Client transport for streamable HTTP, with fallback to SSE. """ - if await _is_old_sse_server(url, timeout): + if await _is_old_sse_server(url, headers=headers, timeout=timeout): async with sse_client(url, headers=headers) as (read_stream, write_stream): yield read_stream, write_stream return @@ -148,7 +148,11 @@ async def post_writer(): ) -async def _is_old_sse_server(url: str, timeout: float) -> bool: +async def _is_old_sse_server( + url: str, + headers: dict[str, Any] | None, + timeout: float, +) -> bool: """ Test whether this is an old SSE MCP server. @@ -176,6 +180,7 @@ async def _is_old_sse_server(url: str, timeout: float) -> bool: headers=( ("accept", "application/json"), ("accept", "text/event-stream"), + *(headers or {}).items(), ), ) if 400 <= response.status_code < 500: