Skip to content

Commit 92ba4f4

Browse files
authored
Merge pull request #206 from modelcontextprotocol/jerome/fix/188
Jerome/fix/188
2 parents 7c47d1f + fbf4acc commit 92ba4f4

File tree

4 files changed

+88
-31
lines changed

4 files changed

+88
-31
lines changed

.git-blame-ignore-revs

Whitespace-only changes.

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.DS_Store
2+
scratch/
23

34
# Byte-compiled / optimized / DLL files
45
__pycache__/

src/mcp/server/lowlevel/server.py

+38-31
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ async def main():
6868
import logging
6969
import warnings
7070
from collections.abc import Awaitable, Callable
71-
from contextlib import AbstractAsyncContextManager, asynccontextmanager
71+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
7272
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
7373

74+
import anyio
7475
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7576
from pydantic import AnyUrl
7677

@@ -469,41 +470,47 @@ async def run(
469470
# in-process servers.
470471
raise_exceptions: bool = False,
471472
):
472-
with warnings.catch_warnings(record=True) as w:
473-
from contextlib import AsyncExitStack
474-
475-
async with AsyncExitStack() as stack:
476-
lifespan_context = await stack.enter_async_context(self.lifespan(self))
477-
session = await stack.enter_async_context(
478-
ServerSession(read_stream, write_stream, initialization_options)
479-
)
473+
async with AsyncExitStack() as stack:
474+
lifespan_context = await stack.enter_async_context(self.lifespan(self))
475+
session = await stack.enter_async_context(
476+
ServerSession(read_stream, write_stream, initialization_options)
477+
)
480478

479+
async with anyio.create_task_group() as tg:
481480
async for message in session.incoming_messages:
482481
logger.debug(f"Received message: {message}")
483482

484-
match message:
485-
case (
486-
RequestResponder(
487-
request=types.ClientRequest(root=req)
488-
) as responder
489-
):
490-
with responder:
491-
await self._handle_request(
492-
message,
493-
req,
494-
session,
495-
lifespan_context,
496-
raise_exceptions,
497-
)
498-
case types.ClientNotification(root=notify):
499-
await self._handle_notification(notify)
500-
501-
for warning in w:
502-
logger.info(
503-
"Warning: %s: %s",
504-
warning.category.__name__,
505-
warning.message,
483+
tg.start_soon(
484+
self._handle_message,
485+
message,
486+
session,
487+
lifespan_context,
488+
raise_exceptions,
489+
)
490+
491+
async def _handle_message(
492+
self,
493+
message: RequestResponder[types.ClientRequest, types.ServerResult]
494+
| types.ClientNotification
495+
| Exception,
496+
session: ServerSession,
497+
lifespan_context: LifespanResultT,
498+
raise_exceptions: bool = False,
499+
):
500+
with warnings.catch_warnings(record=True) as w:
501+
match message:
502+
case (
503+
RequestResponder(request=types.ClientRequest(root=req)) as responder
504+
):
505+
with responder:
506+
await self._handle_request(
507+
message, req, session, lifespan_context, raise_exceptions
506508
)
509+
case types.ClientNotification(root=notify):
510+
await self._handle_notification(notify)
511+
512+
for warning in w:
513+
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
507514

508515
async def _handle_request(
509516
self,

tests/issues/test_188_concurrency.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import anyio
2+
from pydantic import AnyUrl
3+
4+
from mcp.server.fastmcp import FastMCP
5+
from mcp.shared.memory import (
6+
create_connected_server_and_client_session as create_session,
7+
)
8+
9+
_sleep_time_seconds = 0.01
10+
_resource_name = "slow://slow_resource"
11+
12+
13+
async def test_messages_are_executed_concurrently():
14+
server = FastMCP("test")
15+
16+
@server.tool("sleep")
17+
async def sleep_tool():
18+
await anyio.sleep(_sleep_time_seconds)
19+
return "done"
20+
21+
@server.resource(_resource_name)
22+
async def slow_resource():
23+
await anyio.sleep(_sleep_time_seconds)
24+
return "slow"
25+
26+
async with create_session(server._mcp_server) as client_session:
27+
start_time = anyio.current_time()
28+
async with anyio.create_task_group() as tg:
29+
for _ in range(10):
30+
tg.start_soon(client_session.call_tool, "sleep")
31+
tg.start_soon(client_session.read_resource, AnyUrl(_resource_name))
32+
33+
end_time = anyio.current_time()
34+
35+
duration = end_time - start_time
36+
assert duration < 3 * _sleep_time_seconds
37+
print(duration)
38+
39+
40+
def main():
41+
anyio.run(test_messages_are_executed_concurrently)
42+
43+
44+
if __name__ == "__main__":
45+
import logging
46+
47+
logging.basicConfig(level=logging.DEBUG)
48+
49+
main()

0 commit comments

Comments
 (0)