Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamable HTTP client transport #416

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
supported_protocol_versions: tuple[str | int, ...] | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -109,6 +110,9 @@ def __init__(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._supported_protocol_versions = (
supported_protocol_versions or SUPPORTED_PROTOCOL_VERSIONS
)

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
Expand Down Expand Up @@ -137,7 +141,7 @@ async def initialize(self) -> types.InitializeResult:
types.InitializeResult,
)

if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
if result.protocolVersion not in self._supported_protocol_versions:
raise RuntimeError(
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
Expand Down
188 changes: 188 additions & 0 deletions src/mcp/client/streamable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import logging
from contextlib import asynccontextmanager
from typing import Any

import anyio
import httpx
from httpx_sse import EventSource
from pydantic import TypeAdapter

import mcp.types as types
from mcp.client.sse import sse_client

logger = logging.getLogger(__name__)

STREAMABLE_PROTOCOL_VERSION = "2025-03-26"
SUPPORTED_PROTOCOL_VERSIONS: tuple[str, ...] = (
types.LATEST_PROTOCOL_VERSION,
STREAMABLE_PROTOCOL_VERSION,
)


@asynccontextmanager
async def streamable_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
):
"""
Client transport for streamable HTTP, with fallback to SSE.
"""
if await _is_old_sse_server(url, headers=headers, timeout=timeout):
async with sse_client(url, headers=headers) as (read_stream, write_stream):
yield read_stream, write_stream
return

read_stream_writer, read_stream = anyio.create_memory_object_stream[
types.JSONRPCMessage | Exception
](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[
types.JSONRPCMessage
](0)

async def handle_response(text: str) -> None:
items = _maybe_list_adapter.validate_json(text)
if isinstance(items, types.JSONRPCMessage):
items = [items]
for item in items:
await read_stream_writer.send(item)

session_headers = headers.copy() if headers else {}

async with anyio.create_task_group() as tg:
try:
async with httpx.AsyncClient(timeout=timeout) as client:

async def sse_reader(event_source: EventSource):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "message":
try:
await handle_response(sse.data)
logger.debug(
f"Received server message: {sse.data}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()

async def post_writer():
nonlocal headers
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
response = await client.post(
url,
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
headers=(
("accept", "application/json"),
("accept", "text/event-stream"),
*session_headers.items(),
),
)
content_type = response.headers.get("content-type")
logger.debug(
f"response {url=} "
f"content-type={content_type} "
f"body={response.text}"
)

response.raise_for_status()
match response.headers.get("mcp-session-id"):
case str() as session_id:
session_headers["mcp-session-id"] = session_id
case _:
pass

match response.headers.get("content-type"):
case "text/event-stream":
await sse_reader(EventSource(response))
case "application/json":
await handle_response(response.text)
case None:
pass
case unknown:
logger.warning(
f"Unknown content type: {unknown}"
)

logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}", exc_info=True)
finally:
await write_stream.aclose()

tg.start_soon(post_writer)

try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()


_maybe_list_adapter: TypeAdapter[types.JSONRPCMessage | list[types.JSONRPCMessage]] = (
TypeAdapter(types.JSONRPCMessage | list[types.JSONRPCMessage])
)


async def _is_old_sse_server(
url: str,
headers: dict[str, Any] | None,
timeout: float,
) -> bool:
"""
Test whether this is an old SSE MCP server.

See: https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/transports/#backwards-compatibility
"""
async with httpx.AsyncClient(timeout=timeout) as client:
test_initialize_request = types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=STREAMABLE_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
)
response = await client.post(
url,
json=types.JSONRPCRequest(
jsonrpc="2.0",
id=1,
method=test_initialize_request.method,
params=test_initialize_request.params.model_dump(
by_alias=True, mode="json", exclude_none=True
),
).model_dump(by_alias=True, mode="json", exclude_none=True),
headers=(
("accept", "application/json"),
("accept", "text/event-stream"),
*(headers or {}).items(),
),
)
if 400 <= response.status_code < 500:
return True
return False