Skip to content

Commit d3f0dea

Browse files
committed
Change implementation to take in auth headers
1 parent 785964e commit d3f0dea

File tree

2 files changed

+153
-69
lines changed

2 files changed

+153
-69
lines changed

src/mcp/client/streamable_http.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,12 @@ class AuthClientProvider(Protocol):
7878
"""Base class that can be extended to implement custom client-to-server
7979
authentication"""
8080

81-
async def get_token(self) -> str:
82-
"""Get a token for authenticating to an MCP server. The token is assumed to
83-
be short-lived; clients may call this API multiple times per
84-
request to an MCP server.
81+
async def get_auth_headers(self) -> dict[str, str]:
82+
"""Gets auth headers for authenticating to an MCP server.
83+
Clients may call this API multiple times per request to an MCP server.
8584
8685
Returns:
87-
str: The authentication token.
86+
dict[str, str]: The authentication headers.
8887
"""
8988
...
9089

@@ -129,23 +128,22 @@ def _update_headers_with_session(
129128
headers[MCP_SESSION_ID] = self.session_id
130129
return headers
131130

132-
async def _update_headers_with_token(
131+
async def _update_headers_with_auth_headers(
133132
self, base_headers: dict[str, str]
134133
) -> dict[str, str]:
135-
"""Update headers with token if token provider is specified and authorization
136-
header is not present."""
137-
if self.auth_client_provider is None or "Authorization" in base_headers:
134+
"""Update headers with auth_headers if auth client provider is specified.
135+
The headers are merged giving precedence to the base_headers to
136+
avoid overwriting existing Authorization headers"""
137+
if self.auth_client_provider is None:
138138
return base_headers
139139

140-
token = await self.auth_client_provider.get_token()
141-
headers = base_headers.copy()
142-
headers["Authorization"] = f"Bearer {token}"
143-
return headers
140+
auth_headers = await self.auth_client_provider.get_auth_headers()
141+
return {**auth_headers, **base_headers}
144142

145143
async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
146144
"""Update headers with session ID and token if available."""
147145
headers = self._update_headers_with_session(base_headers)
148-
headers = await self._update_headers_with_token(headers)
146+
headers = await self._update_headers_with_auth_headers(headers)
149147
return headers
150148

151149
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -252,7 +250,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
252250
original_request_id = None
253251
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
254252
original_request_id = ctx.session_message.message.root.id
255-
256253
async with aconnect_sse(
257254
ctx.client,
258255
"GET",
@@ -275,6 +272,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
275272
if is_complete:
276273
break
277274

275+
async def _is_testing_header_capture(self, response: httpx.Response) -> str | None:
276+
try:
277+
content = await response.aread()
278+
if content.decode().startswith("[TESTING_HEADER_CAPTURE]"):
279+
return content.decode()
280+
except Exception as _:
281+
return None
282+
283+
return None
284+
278285
async def _handle_post_request(self, ctx: RequestContext) -> None:
279286
"""Handle a POST request with response processing."""
280287
headers = await self._update_headers(ctx.headers)
@@ -299,12 +306,24 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
299306
)
300307
return
301308

309+
if response.status_code == 418:
310+
test_error_message = await self._is_testing_header_capture(response)
311+
# If this is coming from the test case return the response content
312+
if test_error_message and isinstance(message.root, JSONRPCRequest):
313+
jsonrpc_error = JSONRPCError(
314+
jsonrpc="2.0",
315+
id=message.root.id,
316+
error=ErrorData(code=32600, message=test_error_message),
317+
)
318+
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
319+
await ctx.read_stream_writer.send(session_message)
320+
return
321+
302322
response.raise_for_status()
303323
if is_initialization:
304324
self._maybe_extract_session_id_from_response(response)
305325

306326
content_type = response.headers.get(CONTENT_TYPE, "").lower()
307-
308327
if content_type.startswith(JSON):
309328
await self._handle_json_response(response, ctx.read_stream_writer)
310329
elif content_type.startswith(SSE):

tests/shared/test_streamable_http.py

Lines changed: 118 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Contains tests for both server and client sides of the StreamableHTTP transport.
55
"""
66

7+
import json
78
import multiprocessing
89
import socket
910
import time
@@ -18,6 +19,8 @@
1819
import uvicorn
1920
from pydantic import AnyUrl
2021
from starlette.applications import Starlette
22+
from starlette.requests import Request
23+
from starlette.responses import Response
2124
from starlette.routing import Mount
2225

2326
import mcp.types as types
@@ -244,8 +247,46 @@ def create_app(
244247
return app
245248

246249

250+
def create_header_capture_app() -> Starlette:
251+
"""Implement a minimal Starlette app that intercepts every request,
252+
extracts its headers, and responds with status 418 (Test Status code),
253+
embedding the captured headers as the JSON response body.
254+
We use this server solely to verify that the MCP Server is forwarding
255+
headers correctly."""
256+
257+
# Create a wrapper that captures headers and returns them in error response
258+
async def header_capture_wrapper(scope, receive, send):
259+
# Capture headers
260+
request = Request(scope, receive=receive)
261+
headers = dict(request.headers)
262+
263+
# Return error response with headers in body
264+
response = Response(
265+
"[TESTING_HEADER_CAPTURE]:" + json.dumps({"headers": headers}),
266+
status_code=418,
267+
)
268+
await response(scope, receive, send)
269+
270+
# Create an ASGI application that uses our wrapper
271+
app = Starlette(
272+
debug=True,
273+
routes=[
274+
Mount("/mcp", app=header_capture_wrapper),
275+
],
276+
)
277+
278+
return app
279+
280+
281+
def _get_captured_headrs(str) -> dict[str, str]:
282+
return json.loads(str.split("[TESTING_HEADER_CAPTURE]:")[1])["headers"]
283+
284+
247285
def run_server(
248-
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
286+
port: int,
287+
is_json_response_enabled=False,
288+
event_store: EventStore | None = None,
289+
testing_header_capture: bool = False,
249290
) -> None:
250291
"""Run the test server.
251292
@@ -255,7 +296,11 @@ def run_server(
255296
event_store: Optional event store for testing resumability.
256297
"""
257298

258-
app = create_app(is_json_response_enabled, event_store)
299+
if testing_header_capture:
300+
app = create_header_capture_app()
301+
else:
302+
app = create_app(is_json_response_enabled, event_store)
303+
259304
# Configure server
260305
config = uvicorn.Config(
261306
app=app,
@@ -296,33 +341,48 @@ def json_server_port() -> int:
296341
return s.getsockname()[1]
297342

298343

299-
@pytest.fixture
300-
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
301-
"""Start a basic server."""
344+
def _start_basic_server(
345+
basic_server_port: int, testing_header_capture: bool
346+
) -> Generator[None, None, None]:
302347
proc = multiprocessing.Process(
303-
target=run_server, kwargs={"port": basic_server_port}, daemon=True
348+
target=run_server,
349+
kwargs={
350+
"port": basic_server_port,
351+
"testing_header_capture": testing_header_capture,
352+
},
353+
daemon=True,
304354
)
305355
proc.start()
306356

307357
# Wait for server to be running
308358
max_attempts = 20
309-
attempt = 0
310-
while attempt < max_attempts:
359+
for attempt in range(max_attempts):
311360
try:
312361
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
313362
s.connect(("127.0.0.1", basic_server_port))
314363
break
315364
except ConnectionRefusedError:
316365
time.sleep(0.1)
317-
attempt += 1
318366
else:
319367
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
320368

321-
yield
369+
try:
370+
yield
371+
finally:
372+
proc.kill()
373+
proc.join(timeout=2)
322374

323-
# Clean up
324-
proc.kill()
325-
proc.join(timeout=2)
375+
376+
@pytest.fixture
377+
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
378+
yield from _start_basic_server(basic_server_port, testing_header_capture=False)
379+
380+
381+
@pytest.fixture
382+
def basic_server_with_header_capture(
383+
basic_server_port: int,
384+
) -> Generator[None, None, None]:
385+
yield from _start_basic_server(basic_server_port, testing_header_capture=True)
326386

327387

328388
@pytest.fixture
@@ -1232,79 +1292,84 @@ class MockAuthClientProvider:
12321292
def __init__(self, token: str):
12331293
self.token = token
12341294

1235-
async def get_token(self) -> str:
1236-
return self.token
1295+
async def get_auth_headers(self) -> dict[str, str]:
1296+
return {"Authorization": f"Bearer {self.token}"}
12371297

12381298

12391299
@pytest.mark.anyio
1240-
async def test_auth_client_provider_headers(basic_server, basic_server_url):
1300+
async def test_auth_client_provider_headers(
1301+
basic_server_with_header_capture, basic_server_url
1302+
):
12411303
"""Test that auth token provider correctly sets Authorization header."""
12421304
# Create a mock token provider
1243-
client_provider = MockAuthClientProvider("test-token-123")
1244-
client_provider.get_token = AsyncMock(return_value="test-token-123")
1305+
client_provider = MockAuthClientProvider("short-lived-token-123")
12451306

12461307
# Create client with token provider
12471308
async with streamablehttp_client(
12481309
f"{basic_server_url}/mcp", auth_client_provider=client_provider
12491310
) as (read_stream, write_stream, _):
12501311
async with ClientSession(read_stream, write_stream) as session:
12511312
# Initialize the session
1252-
result = await session.initialize()
1253-
assert isinstance(result, InitializeResult)
1254-
1255-
# Make a request to verify headers
1256-
tools = await session.list_tools()
1257-
assert len(tools.tools) == 4
1258-
1259-
client_provider.get_token.assert_called()
1313+
with pytest.raises(McpError) as mcpError:
1314+
_ = await session.initialize()
1315+
assert (
1316+
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1317+
== "Bearer short-lived-token-123"
1318+
)
12601319

12611320

12621321
@pytest.mark.anyio
1263-
async def test_auth_client_provider_token_update(basic_server, basic_server_url):
1322+
async def test_auth_client_provider_token_called_on_every_request(
1323+
basic_server_with_header_capture, basic_server_url
1324+
):
12641325
"""Test that auth token provider can return different tokens."""
12651326
# Create a dynamic token provider
1266-
client_provider = MockAuthClientProvider("test-token-123")
1267-
client_provider.get_token = AsyncMock(return_value="test-token-123")
1327+
client_provider = MockAuthClientProvider("short-lived-token-123")
12681328

1269-
# Create client with dynamic token provider
12701329
async with streamablehttp_client(
12711330
f"{basic_server_url}/mcp", auth_client_provider=client_provider
12721331
) as (read_stream, write_stream, _):
12731332
async with ClientSession(read_stream, write_stream) as session:
12741333
# Initialize the session
1275-
result = await session.initialize()
1276-
assert isinstance(result, InitializeResult)
1277-
1278-
# Make multiple requests to verify token updates
1279-
for i in range(3):
1280-
tools = await session.list_tools()
1281-
assert len(tools.tools) == 4
1334+
with pytest.raises(McpError) as mcpError:
1335+
_ = await session.initialize()
1336+
assert (
1337+
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1338+
== "Bearer short-lived-token-123"
1339+
)
12821340

1283-
client_provider.get_token.call_count > 1
1341+
# Mock a new token and ensure the new token is returned
1342+
client_provider.get_auth_headers = AsyncMock(
1343+
return_value={"Authorization": "Bearer short-lived-token-456"}
1344+
)
1345+
with pytest.raises(McpError) as mcpError:
1346+
_ = await session.initialize()
1347+
assert (
1348+
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1349+
== "Bearer short-lived-token-456"
1350+
)
12841351

12851352

12861353
@pytest.mark.anyio
12871354
async def test_auth_client_provider_headers_not_overridden(
1288-
basic_server, basic_server_url
1355+
basic_server_with_header_capture, basic_server_url
12891356
):
1290-
"""Test that auth token provider correctly sets Authorization header."""
1357+
"""Test that provided headers override auth client provider headers."""
12911358
# Create a mock token provider
1292-
client_provider = MockAuthClientProvider("test-token-123")
1293-
client_provider.get_token = AsyncMock(return_value="test-token-123")
1359+
client_provider = MockAuthClientProvider("short-lived-token")
12941360

1295-
# Create client with token provider
1361+
# Create client with token provider and custom headers
1362+
custom_headers = {"Authorization": "Bearer original-long-lived-token"}
12961363
async with streamablehttp_client(
12971364
f"{basic_server_url}/mcp",
12981365
auth_client_provider=client_provider,
1299-
headers={"Authorization": "test-token-123"},
1366+
headers=custom_headers,
13001367
) as (read_stream, write_stream, _):
13011368
async with ClientSession(read_stream, write_stream) as session:
1302-
# Initialize the session
1303-
result = await session.initialize()
1304-
assert isinstance(result, InitializeResult)
1305-
1306-
# Make a request to verify headers
1307-
tools = await session.list_tools()
1308-
assert len(tools.tools) == 4
1309-
1310-
client_provider.get_token.assert_not_called()
1369+
# Original token is used and not short-lived-token from the provider
1370+
with pytest.raises(McpError) as mcpError:
1371+
_ = await session.initialize()
1372+
assert (
1373+
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1374+
== "Bearer original-long-lived-token"
1375+
)

0 commit comments

Comments
 (0)