diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..e69de29b diff --git a/.gitignore b/.gitignore index bb25f9e4..54006f93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +scratch/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 643e1a27..c0008b32 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,9 +68,10 @@ async def main(): import logging import warnings from collections.abc import Awaitable, Callable -from contextlib import AbstractAsyncContextManager, asynccontextmanager +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from typing import Any, AsyncIterator, Generic, Sequence, TypeVar +import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -469,41 +470,47 @@ async def run( # in-process servers. raise_exceptions: bool = False, ): - with warnings.catch_warnings(record=True) as w: - from contextlib import AsyncExitStack - - async with AsyncExitStack() as stack: - lifespan_context = await stack.enter_async_context(self.lifespan(self)) - session = await stack.enter_async_context( - ServerSession(read_stream, write_stream, initialization_options) - ) + async with AsyncExitStack() as stack: + lifespan_context = await stack.enter_async_context(self.lifespan(self)) + session = await stack.enter_async_context( + ServerSession(read_stream, write_stream, initialization_options) + ) + async with anyio.create_task_group() as tg: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") - match message: - case ( - RequestResponder( - request=types.ClientRequest(root=req) - ) as responder - ): - with responder: - await self._handle_request( - message, - req, - session, - lifespan_context, - raise_exceptions, - ) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info( - "Warning: %s: %s", - warning.category.__name__, - warning.message, + tg.start_soon( + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + + async def _handle_message( + self, + message: RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception, + session: ServerSession, + lifespan_context: LifespanResultT, + raise_exceptions: bool = False, + ): + with warnings.catch_warnings(record=True) as w: + match message: + case ( + RequestResponder(request=types.ClientRequest(root=req)) as responder + ): + with responder: + await self._handle_request( + message, req, session, lifespan_context, raise_exceptions ) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) + + for warning in w: + logger.info(f"Warning: {warning.category.__name__}: {warning.message}") async def _handle_request( self, diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py new file mode 100644 index 00000000..a56c0d30 --- /dev/null +++ b/tests/issues/test_188_concurrency.py @@ -0,0 +1,49 @@ +import anyio +from pydantic import AnyUrl + +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) + +_sleep_time_seconds = 0.01 +_resource_name = "slow://slow_resource" + + +async def test_messages_are_executed_concurrently(): + server = FastMCP("test") + + @server.tool("sleep") + async def sleep_tool(): + await anyio.sleep(_sleep_time_seconds) + return "done" + + @server.resource(_resource_name) + async def slow_resource(): + await anyio.sleep(_sleep_time_seconds) + return "slow" + + async with create_session(server._mcp_server) as client_session: + start_time = anyio.current_time() + async with anyio.create_task_group() as tg: + for _ in range(10): + tg.start_soon(client_session.call_tool, "sleep") + tg.start_soon(client_session.read_resource, AnyUrl(_resource_name)) + + end_time = anyio.current_time() + + duration = end_time - start_time + assert duration < 3 * _sleep_time_seconds + print(duration) + + +def main(): + anyio.run(test_messages_are_executed_concurrently) + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.DEBUG) + + main()