diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a7..0194c02a 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -65,8 +65,9 @@ async def sse_reader( logger.info( f"Received endpoint URL: {endpoint_url}" ) - + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) if ( url_parsed.netloc != endpoint_parsed.netloc diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25b..6ec4c165 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -32,6 +32,7 @@ async def handle_sse(request): """ import logging +import re from contextlib import asynccontextmanager from typing import Any from urllib.parse import quote @@ -95,7 +96,14 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + request_path = scope["path"] + + match = re.match(r"^/([^/]+(?:/mcp)?)/sse$", request_path) + mount_prefix = match.group(1) if match else "" + + session_uri = f"/{quote(mount_prefix)}{quote(self._endpoint)}" + session_uri += f"?session_id={session_id.hex}" + self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}")