4
4
Contains tests for both server and client sides of the StreamableHTTP transport.
5
5
"""
6
6
7
+ import json
7
8
import multiprocessing
8
9
import socket
9
10
import time
18
19
import uvicorn
19
20
from pydantic import AnyUrl
20
21
from starlette .applications import Starlette
22
+ from starlette .requests import Request
23
+ from starlette .responses import Response
21
24
from starlette .routing import Mount
22
25
23
26
import mcp .types as types
@@ -244,8 +247,46 @@ def create_app(
244
247
return app
245
248
246
249
250
+ def create_header_capture_app () -> Starlette :
251
+ """Implement a minimal Starlette app that intercepts every request,
252
+ extracts its headers, and responds with status 418 (Test Status code),
253
+ embedding the captured headers as the JSON response body.
254
+ We use this server solely to verify that the MCP Server is forwarding
255
+ headers correctly."""
256
+
257
+ # Create a wrapper that captures headers and returns them in error response
258
+ async def header_capture_wrapper (scope , receive , send ):
259
+ # Capture headers
260
+ request = Request (scope , receive = receive )
261
+ headers = dict (request .headers )
262
+
263
+ # Return error response with headers in body
264
+ response = Response (
265
+ "[TESTING_HEADER_CAPTURE]:" + json .dumps ({"headers" : headers }),
266
+ status_code = 418 ,
267
+ )
268
+ await response (scope , receive , send )
269
+
270
+ # Create an ASGI application that uses our wrapper
271
+ app = Starlette (
272
+ debug = True ,
273
+ routes = [
274
+ Mount ("/mcp" , app = header_capture_wrapper ),
275
+ ],
276
+ )
277
+
278
+ return app
279
+
280
+
281
+ def _get_captured_headrs (str ) -> dict [str , str ]:
282
+ return json .loads (str .split ("[TESTING_HEADER_CAPTURE]:" )[1 ])["headers" ]
283
+
284
+
247
285
def run_server (
248
- port : int , is_json_response_enabled = False , event_store : EventStore | None = None
286
+ port : int ,
287
+ is_json_response_enabled = False ,
288
+ event_store : EventStore | None = None ,
289
+ testing_header_capture : bool = False ,
249
290
) -> None :
250
291
"""Run the test server.
251
292
@@ -255,7 +296,11 @@ def run_server(
255
296
event_store: Optional event store for testing resumability.
256
297
"""
257
298
258
- app = create_app (is_json_response_enabled , event_store )
299
+ if testing_header_capture :
300
+ app = create_header_capture_app ()
301
+ else :
302
+ app = create_app (is_json_response_enabled , event_store )
303
+
259
304
# Configure server
260
305
config = uvicorn .Config (
261
306
app = app ,
@@ -296,33 +341,48 @@ def json_server_port() -> int:
296
341
return s .getsockname ()[1 ]
297
342
298
343
299
- @ pytest . fixture
300
- def basic_server ( basic_server_port : int ) -> Generator [ None , None , None ]:
301
- """Start a basic server."""
344
+ def _start_basic_server (
345
+ basic_server_port : int , testing_header_capture : bool
346
+ ) -> Generator [ None , None , None ]:
302
347
proc = multiprocessing .Process (
303
- target = run_server , kwargs = {"port" : basic_server_port }, daemon = True
348
+ target = run_server ,
349
+ kwargs = {
350
+ "port" : basic_server_port ,
351
+ "testing_header_capture" : testing_header_capture ,
352
+ },
353
+ daemon = True ,
304
354
)
305
355
proc .start ()
306
356
307
357
# Wait for server to be running
308
358
max_attempts = 20
309
- attempt = 0
310
- while attempt < max_attempts :
359
+ for attempt in range (max_attempts ):
311
360
try :
312
361
with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
313
362
s .connect (("127.0.0.1" , basic_server_port ))
314
363
break
315
364
except ConnectionRefusedError :
316
365
time .sleep (0.1 )
317
- attempt += 1
318
366
else :
319
367
raise RuntimeError (f"Server failed to start after { max_attempts } attempts" )
320
368
321
- yield
369
+ try :
370
+ yield
371
+ finally :
372
+ proc .kill ()
373
+ proc .join (timeout = 2 )
322
374
323
- # Clean up
324
- proc .kill ()
325
- proc .join (timeout = 2 )
375
+
376
+ @pytest .fixture
377
+ def basic_server (basic_server_port : int ) -> Generator [None , None , None ]:
378
+ yield from _start_basic_server (basic_server_port , testing_header_capture = False )
379
+
380
+
381
+ @pytest .fixture
382
+ def basic_server_with_header_capture (
383
+ basic_server_port : int ,
384
+ ) -> Generator [None , None , None ]:
385
+ yield from _start_basic_server (basic_server_port , testing_header_capture = True )
326
386
327
387
328
388
@pytest .fixture
@@ -1232,79 +1292,84 @@ class MockAuthClientProvider:
1232
1292
def __init__ (self , token : str ):
1233
1293
self .token = token
1234
1294
1235
- async def get_token (self ) -> str :
1236
- return self .token
1295
+ async def get_auth_headers (self ) -> dict [ str , str ] :
1296
+ return { "Authorization" : f"Bearer { self .token } " }
1237
1297
1238
1298
1239
1299
@pytest .mark .anyio
1240
- async def test_auth_client_provider_headers (basic_server , basic_server_url ):
1300
+ async def test_auth_client_provider_headers (
1301
+ basic_server_with_header_capture , basic_server_url
1302
+ ):
1241
1303
"""Test that auth token provider correctly sets Authorization header."""
1242
1304
# Create a mock token provider
1243
- client_provider = MockAuthClientProvider ("test-token-123" )
1244
- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1305
+ client_provider = MockAuthClientProvider ("short-lived-token-123" )
1245
1306
1246
1307
# Create client with token provider
1247
1308
async with streamablehttp_client (
1248
1309
f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
1249
1310
) as (read_stream , write_stream , _ ):
1250
1311
async with ClientSession (read_stream , write_stream ) as session :
1251
1312
# Initialize the session
1252
- result = await session .initialize ()
1253
- assert isinstance (result , InitializeResult )
1254
-
1255
- # Make a request to verify headers
1256
- tools = await session .list_tools ()
1257
- assert len (tools .tools ) == 4
1258
-
1259
- client_provider .get_token .assert_called ()
1313
+ with pytest .raises (McpError ) as mcpError :
1314
+ _ = await session .initialize ()
1315
+ assert (
1316
+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1317
+ == "Bearer short-lived-token-123"
1318
+ )
1260
1319
1261
1320
1262
1321
@pytest .mark .anyio
1263
- async def test_auth_client_provider_token_update (basic_server , basic_server_url ):
1322
+ async def test_auth_client_provider_token_called_on_every_request (
1323
+ basic_server_with_header_capture , basic_server_url
1324
+ ):
1264
1325
"""Test that auth token provider can return different tokens."""
1265
1326
# Create a dynamic token provider
1266
- client_provider = MockAuthClientProvider ("test-token-123" )
1267
- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1327
+ client_provider = MockAuthClientProvider ("short-lived-token-123" )
1268
1328
1269
- # Create client with dynamic token provider
1270
1329
async with streamablehttp_client (
1271
1330
f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
1272
1331
) as (read_stream , write_stream , _ ):
1273
1332
async with ClientSession (read_stream , write_stream ) as session :
1274
1333
# Initialize the session
1275
- result = await session .initialize ()
1276
- assert isinstance (result , InitializeResult )
1277
-
1278
- # Make multiple requests to verify token updates
1279
- for i in range (3 ):
1280
- tools = await session .list_tools ()
1281
- assert len (tools .tools ) == 4
1334
+ with pytest .raises (McpError ) as mcpError :
1335
+ _ = await session .initialize ()
1336
+ assert (
1337
+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1338
+ == "Bearer short-lived-token-123"
1339
+ )
1282
1340
1283
- client_provider .get_token .call_count > 1
1341
+ # Mock a new token and ensure the new token is returned
1342
+ client_provider .get_auth_headers = AsyncMock (
1343
+ return_value = {"Authorization" : "Bearer short-lived-token-456" }
1344
+ )
1345
+ with pytest .raises (McpError ) as mcpError :
1346
+ _ = await session .initialize ()
1347
+ assert (
1348
+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1349
+ == "Bearer short-lived-token-456"
1350
+ )
1284
1351
1285
1352
1286
1353
@pytest .mark .anyio
1287
1354
async def test_auth_client_provider_headers_not_overridden (
1288
- basic_server , basic_server_url
1355
+ basic_server_with_header_capture , basic_server_url
1289
1356
):
1290
- """Test that auth token provider correctly sets Authorization header ."""
1357
+ """Test that provided headers override auth client provider headers ."""
1291
1358
# Create a mock token provider
1292
- client_provider = MockAuthClientProvider ("test-token-123" )
1293
- client_provider .get_token = AsyncMock (return_value = "test-token-123" )
1359
+ client_provider = MockAuthClientProvider ("short-lived-token" )
1294
1360
1295
- # Create client with token provider
1361
+ # Create client with token provider and custom headers
1362
+ custom_headers = {"Authorization" : "Bearer original-long-lived-token" }
1296
1363
async with streamablehttp_client (
1297
1364
f"{ basic_server_url } /mcp" ,
1298
1365
auth_client_provider = client_provider ,
1299
- headers = { "Authorization" : "test-token-123" } ,
1366
+ headers = custom_headers ,
1300
1367
) as (read_stream , write_stream , _ ):
1301
1368
async with ClientSession (read_stream , write_stream ) as session :
1302
- # Initialize the session
1303
- result = await session .initialize ()
1304
- assert isinstance (result , InitializeResult )
1305
-
1306
- # Make a request to verify headers
1307
- tools = await session .list_tools ()
1308
- assert len (tools .tools ) == 4
1309
-
1310
- client_provider .get_token .assert_not_called ()
1369
+ # Original token is used and not short-lived-token from the provider
1370
+ with pytest .raises (McpError ) as mcpError :
1371
+ _ = await session .initialize ()
1372
+ assert (
1373
+ _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1374
+ == "Bearer original-long-lived-token"
1375
+ )
0 commit comments