Skip to content

Dynamic Authorization in Streamable HTTP Client #700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any
from typing import Any, Protocol

import anyio
import httpx
Expand Down Expand Up @@ -74,6 +74,20 @@ class RequestContext:
sse_read_timeout: timedelta


class AuthClientProvider(Protocol):
"""Base class that can be extended to implement custom client-to-server
authentication"""

async def get_headers(self) -> dict[str, str]:
"""Gets auth headers for authenticating to an MCP server.
Clients may call this API multiple times per request to an MCP server.

Returns:
dict[str, str]: The authentication headers.
"""
...


class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

Expand All @@ -83,6 +97,7 @@ def __init__(
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
auth_client_provider: AuthClientProvider | None = None,
) -> None:
"""Initialize the StreamableHTTP transport.

Expand All @@ -102,6 +117,7 @@ def __init__(
CONTENT_TYPE: JSON,
**self.headers,
Copy link

@smurching smurching May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I would expect any auth headers passed in self.headers to take precedence over the auth_provider headers - in many APIs explicitly setting a value (e.g. explicitly setting the value of the Authorization header via headers) takes precedence over implicitly inferring it. For example, I've seen that pattern often in MLflow. But will defer to the maintainers' opinion on that - in either case, it'd be good to have a test that exercises the behavior

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, I added the behaviour to not override passed in headers, and add a test case as well

}
self.auth_client_provider = auth_client_provider

def _update_headers_with_session(
self, base_headers: dict[str, str]
Expand All @@ -112,6 +128,24 @@ def _update_headers_with_session(
headers[MCP_SESSION_ID] = self.session_id
return headers

async def _update_headers_with_auth_headers(
self, base_headers: dict[str, str]
) -> dict[str, str]:
"""Update headers with auth_headers if auth client provider is specified.
The headers are merged, giving precedence to any headers already
specified in base_headers"""
if self.auth_client_provider is None:
return base_headers

auth_headers = await self.auth_client_provider.get_headers()
return {**auth_headers, **base_headers}

async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and token if available."""
headers = self._update_headers_with_session(base_headers)
headers = await self._update_headers_with_auth_headers(headers)
return headers

def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return (
Expand Down Expand Up @@ -184,7 +218,7 @@ async def handle_get_stream(
if not self.session_id:
return

headers = self._update_headers_with_session(self.request_headers)
headers = await self._update_headers(self.request_headers)

async with aconnect_sse(
client,
Expand All @@ -206,7 +240,7 @@ async def handle_get_stream(

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
headers = await self._update_headers(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
Expand All @@ -216,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id

async with aconnect_sse(
ctx.client,
"GET",
Expand All @@ -241,7 +274,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
headers = await self._update_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand All @@ -268,7 +301,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
self._maybe_extract_session_id_from_response(response)

content_type = response.headers.get(CONTENT_TYPE, "").lower()

if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer)
elif content_type.startswith(SSE):
Expand Down Expand Up @@ -405,7 +437,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
return

try:
headers = self._update_headers_with_session(self.request_headers)
headers = await self._update_headers(self.request_headers)
response = await client.delete(self.url, headers=headers)

if response.status_code == 405:
Expand All @@ -427,6 +459,7 @@ async def streamablehttp_client(
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
auth_client_provider: AuthClientProvider | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
Expand All @@ -441,13 +474,21 @@ async def streamablehttp_client(
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.

`auth_client_provider` instance of `AuthClientProvider` that can be passed to
support client-to-server authentication. Before each request to the MCP Server,
the auth_client_provider.get_headers() method is invoked to retrieve headers
for authentication. Note that any headers already specified in `headers`
will take precedence over headers returned by auth_client_provider.get_headers()

Yields:
Tuple containing:
- read_stream: Stream for reading messages from the server
- write_stream: Stream for sending messages to the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
transport = StreamableHTTPTransport(
url, headers, timeout, sse_read_timeout, auth_client_provider
)

read_stream_writer, read_stream = anyio.create_memory_object_stream[
SessionMessage | Exception
Expand Down
54 changes: 54 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock

import anyio
import httpx
Expand Down Expand Up @@ -1223,3 +1224,56 @@ async def sampling_callback(
captured_message_params.messages[0].content.text
== "Server needs client sampling"
)


@pytest.mark.anyio
async def test_auth_client_provider_headers(basic_server, basic_server_url):
"""Test that auth token provider correctly sets Authorization header."""
# Create a mock token provider
client_provider = AsyncMock()
client_provider.get_headers.return_value = {
"Authorization": "Bearer test-token-123"
}

# Create client with token provider
async with streamablehttp_client(
f"{basic_server_url}/mcp", auth_client_provider=client_provider
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)

# Make a request to verify headers
tools = await session.list_tools()
assert len(tools.tools) == 4

client_provider.get_headers.assert_called()


@pytest.mark.anyio
async def test_auth_client_provider_called_per_request(basic_server, basic_server_url):
"""Test that auth token provider can return different tokens."""
# Create a dynamic token provider
client_provider = AsyncMock()
client_provider.get_headers.return_value = {
"Authorization": "Bearer test-token-123"
}

# Create client with dynamic token provider
async with streamablehttp_client(
f"{basic_server_url}/mcp", auth_client_provider=client_provider
) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)

# Make multiple requests to verify token updates

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question, where do we verify the token is actually updated?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is really hard in this testing environment to get the headers and verify the implementation. There is a mock server which hosts a list and set tools. We create a session, and add our messages to the write stream. This is then read by our transport layer and a request is sent to the server. I could not find a way to intercept or inspect this request object to verify the headers. So I just ensured the method was being called

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could try to build a custom app, that looks at the header, then calls the server, and returns the auth headers in the response headers. I will wait for the maintainer to chime in if they have better ideas on how to test this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

off the cuff

I suspect you'd want to patch:

  • create_mcp_http_client
  • and/or aconnect_sse

In the appropriate places to catch all calls with headers, and assert your headers from the provider are there. I think the mocks could just pass through to the original function

for i in range(3):
tools = await session.list_tools()
assert len(tools.tools) == 4

# list_tools is called 3 times, but get_auth_headers is also used during
# session initialization and setup. Verify it's called at least 3 times.
assert client_provider.get_headers.call_count > 3
Loading