Skip to content

Commit da53a97

Browse files
jerome3o-anthropicdsp-ant
authored andcommitted
Made message handling concurrent
1 parent 9abfe77 commit da53a97

File tree

1 file changed

+39
-32
lines changed

1 file changed

+39
-32
lines changed

src/mcp/server/lowlevel/server.py

+39-32
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ async def main():
6868
import logging
6969
import warnings
7070
from collections.abc import Awaitable, Callable
71-
from contextlib import AbstractAsyncContextManager, asynccontextmanager
71+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
7272
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
7373

74+
import anyio
7475
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7576
from pydantic import AnyUrl
7677

@@ -458,6 +459,30 @@ async def handler(req: types.CompleteRequest):
458459

459460
return decorator
460461

462+
async def _handle_message(
463+
self,
464+
message: RequestResponder[types.ClientRequest, types.ServerResult]
465+
| types.ClientNotification
466+
| Exception,
467+
session: ServerSession,
468+
lifespan_context: LifespanResultT,
469+
raise_exceptions: bool = False,
470+
):
471+
with warnings.catch_warnings(record=True) as w:
472+
match message:
473+
case (
474+
RequestResponder(request=types.ClientRequest(root=req)) as responder
475+
):
476+
with responder:
477+
await self._handle_request(
478+
message, req, session, lifespan_context, raise_exceptions
479+
)
480+
case types.ClientNotification(root=notify):
481+
await self._handle_notification(notify)
482+
483+
for warning in w:
484+
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
485+
461486
async def run(
462487
self,
463488
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
@@ -469,41 +494,23 @@ async def run(
469494
# in-process servers.
470495
raise_exceptions: bool = False,
471496
):
472-
with warnings.catch_warnings(record=True) as w:
473-
from contextlib import AsyncExitStack
474-
475-
async with AsyncExitStack() as stack:
476-
lifespan_context = await stack.enter_async_context(self.lifespan(self))
477-
session = await stack.enter_async_context(
478-
ServerSession(read_stream, write_stream, initialization_options)
479-
)
497+
async with AsyncExitStack() as stack:
498+
lifespan_context = await stack.enter_async_context(self.lifespan(self))
499+
session = await stack.enter_async_context(
500+
ServerSession(read_stream, write_stream, initialization_options)
501+
)
480502

503+
async with anyio.create_task_group() as tg:
481504
async for message in session.incoming_messages:
482505
logger.debug(f"Received message: {message}")
483506

484-
match message:
485-
case (
486-
RequestResponder(
487-
request=types.ClientRequest(root=req)
488-
) as responder
489-
):
490-
with responder:
491-
await self._handle_request(
492-
message,
493-
req,
494-
session,
495-
lifespan_context,
496-
raise_exceptions,
497-
)
498-
case types.ClientNotification(root=notify):
499-
await self._handle_notification(notify)
500-
501-
for warning in w:
502-
logger.info(
503-
"Warning: %s: %s",
504-
warning.category.__name__,
505-
warning.message,
506-
)
507+
tg.start_soon(
508+
self._handle_message,
509+
message,
510+
session,
511+
lifespan_context,
512+
raise_exceptions,
513+
)
507514

508515
async def _handle_request(
509516
self,

0 commit comments

Comments
 (0)