diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a..13468c3b4 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any +from typing import Any, Protocol import anyio import httpx @@ -74,6 +74,20 @@ class RequestContext: sse_read_timeout: timedelta +class AuthClientProvider(Protocol): + """Base class that can be extended to implement custom client-to-server + authentication""" + + async def get_headers(self) -> dict[str, str]: + """Gets auth headers for authenticating to an MCP server. + Clients may call this API multiple times per request to an MCP server. + + Returns: + dict[str, str]: The authentication headers. + """ + ... + + class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" @@ -83,6 +97,7 @@ def __init__( headers: dict[str, Any] | None = None, timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + auth_client_provider: AuthClientProvider | None = None, ) -> None: """Initialize the StreamableHTTP transport. @@ -102,6 +117,7 @@ def __init__( CONTENT_TYPE: JSON, **self.headers, } + self.auth_client_provider = auth_client_provider def _update_headers_with_session( self, base_headers: dict[str, str] @@ -112,6 +128,24 @@ def _update_headers_with_session( headers[MCP_SESSION_ID] = self.session_id return headers + async def _update_headers_with_auth_headers( + self, base_headers: dict[str, str] + ) -> dict[str, str]: + """Update headers with auth_headers if auth client provider is specified. + The headers are merged, giving precedence to any headers already + specified in base_headers""" + if self.auth_client_provider is None: + return base_headers + + auth_headers = await self.auth_client_provider.get_headers() + return {**auth_headers, **base_headers} + + async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID and token if available.""" + headers = self._update_headers_with_session(base_headers) + headers = await self._update_headers_with_auth_headers(headers) + return headers + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" return ( @@ -184,7 +218,7 @@ async def handle_get_stream( if not self.session_id: return - headers = self._update_headers_with_session(self.request_headers) + headers = await self._update_headers(self.request_headers) async with aconnect_sse( client, @@ -206,7 +240,7 @@ async def handle_get_stream( async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._update_headers_with_session(ctx.headers) + headers = await self._update_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -216,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: original_request_id = None if isinstance(ctx.session_message.message.root, JSONRPCRequest): original_request_id = ctx.session_message.message.root.id - async with aconnect_sse( ctx.client, "GET", @@ -241,7 +274,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._update_headers_with_session(ctx.headers) + headers = await self._update_headers(ctx.headers) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -268,7 +301,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: self._maybe_extract_session_id_from_response(response) content_type = response.headers.get(CONTENT_TYPE, "").lower() - if content_type.startswith(JSON): await self._handle_json_response(response, ctx.read_stream_writer) elif content_type.startswith(SSE): @@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: return try: - headers = self._update_headers_with_session(self.request_headers) + headers = await self._update_headers(self.request_headers) response = await client.delete(self.url, headers=headers) if response.status_code == 405: @@ -427,6 +459,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + auth_client_provider: AuthClientProvider | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -441,13 +474,21 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + `auth_client_provider` instance of `AuthClientProvider` that can be passed to + support client-to-server authentication. Before each request to the MCP Server, + the auth_client_provider.get_headers() method is invoked to retrieve headers + for authentication. Note that any headers already specified in `headers` + will take precedence over headers returned by auth_client_provider.get_headers() + Yields: Tuple containing: - read_stream: Stream for reading messages from the server - write_stream: Stream for sending messages to the server - get_session_id_callback: Function to retrieve the current session ID """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + transport = StreamableHTTPTransport( + url, headers, timeout, sse_read_timeout, auth_client_provider + ) read_stream_writer, read_stream = anyio.create_memory_object_stream[ SessionMessage | Exception diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f1c7ef809..ab6d79686 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -9,6 +9,7 @@ import time from collections.abc import Generator from typing import Any +from unittest.mock import AsyncMock import anyio import httpx @@ -1223,3 +1224,56 @@ async def sampling_callback( captured_message_params.messages[0].content.text == "Server needs client sampling" ) + + +@pytest.mark.anyio +async def test_auth_client_provider_headers(basic_server, basic_server_url): + """Test that auth token provider correctly sets Authorization header.""" + # Create a mock token provider + client_provider = AsyncMock() + client_provider.get_headers.return_value = { + "Authorization": "Bearer test-token-123" + } + + # Create client with token provider + async with streamablehttp_client( + f"{basic_server_url}/mcp", auth_client_provider=client_provider + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to verify headers + tools = await session.list_tools() + assert len(tools.tools) == 4 + + client_provider.get_headers.assert_called() + + +@pytest.mark.anyio +async def test_auth_client_provider_called_per_request(basic_server, basic_server_url): + """Test that auth token provider can return different tokens.""" + # Create a dynamic token provider + client_provider = AsyncMock() + client_provider.get_headers.return_value = { + "Authorization": "Bearer test-token-123" + } + + # Create client with dynamic token provider + async with streamablehttp_client( + f"{basic_server_url}/mcp", auth_client_provider=client_provider + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify token updates + for i in range(3): + tools = await session.list_tools() + assert len(tools.tools) == 4 + + # list_tools is called 3 times, but get_auth_headers is also used during + # session initialization and setup. Verify it's called at least 3 times. + assert client_provider.get_headers.call_count > 3