Skip to content

Commit bbcf955

Browse files
fix(modelcontextprotocol/python-sdk#206): port fix for concurrent requests from mcp
Signed-off-by: Radek Ježek <[email protected]> Co-authored-by: Jerome <[email protected]>
1 parent 643f38f commit bbcf955

File tree

2 files changed

+39
-35
lines changed
  • apps/beeai-server/src/beeai_server/services/mcp_proxy
  • packages/acp-python-sdk/src/acp/server/lowlevel

2 files changed

+39
-35
lines changed

apps/beeai-server/src/beeai_server/services/mcp_proxy/proxy_server.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22
import uuid
3-
import warnings
43
from contextlib import asynccontextmanager
54
from functools import cached_property
65

6+
import anyio
77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
88
from kink import inject
99

@@ -19,7 +19,7 @@
1919
from acp import ServerSession, types
2020
from acp.server import Server
2121
from acp.server.models import InitializationOptions
22-
from acp.shared.session import RequestResponder, ReceiveResultT
22+
from acp.shared.session import ReceiveResultT
2323
from acp.types import (
2424
CallToolRequestParams,
2525
ClientRequest,
@@ -133,18 +133,9 @@ async def run_server(
133133
HACK: Modified server.run method that subscribes and forwards messages
134134
The default method sets Request ContextVar only for client requests, not notifications.
135135
"""
136-
with warnings.catch_warnings(record=True) as w:
137-
async with ServerSession(read_stream, write_stream, initialization_options) as session:
138-
async with self._provider_container.forward_notifications(session):
136+
async with ServerSession(read_stream, write_stream, initialization_options) as session:
137+
async with self._provider_container.forward_notifications(session):
138+
async with anyio.create_task_group() as tg:
139139
async for message in session.incoming_messages:
140140
logger.debug(f"Received message: {message}")
141-
142-
match message:
143-
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
144-
with responder:
145-
await self.app._handle_request(message, req, session, raise_exceptions)
146-
case types.ClientNotification(root=notify):
147-
await self.app._handle_notification(notify)
148-
149-
for warning in w:
150-
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
141+
tg.start_soon(self.app._handle_message, message, session, raise_exceptions)

packages/acp-python-sdk/src/acp/server/lowlevel/server.py

+33-20
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ async def main():
6868
import logging
6969
import warnings
7070
from collections.abc import Awaitable, Callable
71+
from contextlib import AsyncExitStack
7172
from typing import Any, Sequence
7273

74+
import anyio
7375
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7476
from pydantic import AnyUrl
7577

@@ -538,30 +540,41 @@ async def run(
538540
# in-process servers.
539541
raise_exceptions: bool = False,
540542
):
541-
with warnings.catch_warnings(record=True) as w:
542-
async with ServerSession(
543-
read_stream, write_stream, initialization_options
544-
) as session:
543+
async with AsyncExitStack() as stack:
544+
session = await stack.enter_async_context(
545+
ServerSession(read_stream, write_stream, initialization_options)
546+
)
547+
548+
async with anyio.create_task_group() as tg:
545549
async for message in session.incoming_messages:
546550
logger.debug(f"Received message: {message}")
547551

548-
match message:
549-
case (
550-
RequestResponder(
551-
request=types.ClientRequest(root=req)
552-
) as responder
553-
):
554-
with responder:
555-
await self._handle_request(
556-
message, req, session, raise_exceptions
557-
)
558-
case types.ClientNotification(root=notify):
559-
await self._handle_notification(notify)
560-
561-
for warning in w:
562-
logger.info(
563-
f"Warning: {warning.category.__name__}: {warning.message}"
552+
tg.start_soon(
553+
self._handle_message, message, session, raise_exceptions
554+
)
555+
556+
async def _handle_message(
557+
self,
558+
message: RequestResponder[types.ClientRequest, types.ServerResult]
559+
| types.ClientNotification
560+
| Exception,
561+
session: ServerSession,
562+
raise_exceptions: bool = False,
563+
):
564+
with warnings.catch_warnings(record=True) as w:
565+
match message:
566+
case (
567+
RequestResponder(request=types.ClientRequest(root=req)) as responder
568+
):
569+
with responder:
570+
await self._handle_request(
571+
message, req, session, raise_exceptions
564572
)
573+
case types.ClientNotification(root=notify):
574+
await self._handle_notification(notify)
575+
576+
for warning in w:
577+
logger.info(f"Warning: {warning.category.__name__}: {warning.message}")
565578

566579
async def _handle_request(
567580
self,

0 commit comments

Comments
 (0)