diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py
index 4b13709c6..5f5f873ee 100644
--- a/src/mcp/shared/session.py
+++ b/src/mcp/shared/session.py
@@ -350,76 +350,94 @@ async def _receive_loop(self) -> None:
             self._read_stream,
             self._write_stream,
         ):
-            async for message in self._read_stream:
-                if isinstance(message, Exception):
-                    await self._handle_incoming(message)
-                elif isinstance(message.message.root, JSONRPCRequest):
-                    validated_request = self._receive_request_type.model_validate(
-                        message.message.root.model_dump(
-                            by_alias=True, mode="json", exclude_none=True
+            async with anyio.create_task_group() as tg:
+                async for message in self._read_stream:
+                    if isinstance(message, Exception):
+                        await self._handle_incoming(message)
+                    elif isinstance(message.message.root, JSONRPCRequest):
+                        validated_request = self._receive_request_type.model_validate(
+                            message.message.root.model_dump(
+                                by_alias=True, mode="json", exclude_none=True
+                            )
+                        )
+                        responder = RequestResponder(
+                            request_id=message.message.root.id,
+                            request_meta=validated_request.root.params.meta
+                            if validated_request.root.params
+                            else None,
+                            request=validated_request,
+                            session=self,
+                            on_complete=lambda r: self._in_flight.pop(
+                                r.request_id, None
+                            ),
+                            message_metadata=message.metadata,
                         )
-                    )
-                    responder = RequestResponder(
-                        request_id=message.message.root.id,
-                        request_meta=validated_request.root.params.meta
-                        if validated_request.root.params
-                        else None,
-                        request=validated_request,
-                        session=self,
-                        on_complete=lambda r: self._in_flight.pop(r.request_id, None),
-                        message_metadata=message.metadata,
-                    )
 
-                    self._in_flight[responder.request_id] = responder
-                    await self._received_request(responder)
+                        # Hand the responder over as an argument. A closure would
+                        # look up the loop variable late and, after an await, could
+                        # see the responder of a request received in the meantime.
+                        async def _handle_received_request(
+                            responder: RequestResponder[ReceiveRequestT, SendResultT],
+                        ) -> None:
+                            await self._received_request(responder)
+                            if not responder._completed:  # type: ignore[reportPrivateUsage]
+                                await self._handle_incoming(responder)
 
-                    if not responder._completed:  # type: ignore[reportPrivateUsage]
-                        await self._handle_incoming(responder)
+                        self._in_flight[responder.request_id] = responder
+                        tg.start_soon(_handle_received_request, responder)
 
-                elif isinstance(message.message.root, JSONRPCNotification):
-                    try:
-                        notification = self._receive_notification_type.model_validate(
-                            message.message.root.model_dump(
-                                by_alias=True, mode="json", exclude_none=True
+                    elif isinstance(message.message.root, JSONRPCNotification):
+                        try:
+                            notification = (
+                                self._receive_notification_type.model_validate(
+                                    message.message.root.model_dump(
+                                        by_alias=True, mode="json", exclude_none=True
+                                    )
+                                )
                             )
-                        )
-                        # Handle cancellation notifications
-                        if isinstance(notification.root, CancelledNotification):
-                            cancelled_id = notification.root.params.requestId
-                            if cancelled_id in self._in_flight:
-                                await self._in_flight[cancelled_id].cancel()
-                        else:
-                            # Handle progress notifications callback
-                            if isinstance(notification.root, ProgressNotification):
-                                progress_token = notification.root.params.progressToken
-                                # If there is a progress callback for this token,
-                                # call it with the progress information
-                                if progress_token in self._progress_callbacks:
-                                    callback = self._progress_callbacks[progress_token]
-                                    await callback(
-                                        notification.root.params.progress,
-                                        notification.root.params.total,
-                                        notification.root.params.message,
+                            # Handle cancellation notifications
+                            if isinstance(notification.root, CancelledNotification):
+                                cancelled_id = notification.root.params.requestId
+                                if cancelled_id in self._in_flight:
+                                    await self._in_flight[cancelled_id].cancel()
+                            else:
+                                # Handle progress notifications callback
+                                if isinstance(notification.root, ProgressNotification):
+                                    progress_token = (
+                                        notification.root.params.progressToken
                                     )
-                            await self._received_notification(notification)
-                            await self._handle_incoming(notification)
-                    except Exception as e:
-                        # For other validation errors, log and continue
-                        logging.warning(
-                            f"Failed to validate notification: {e}. "
-                            f"Message was: {message.message.root}"
-                        )
-                else:  # Response or error
-                    stream = self._response_streams.pop(message.message.root.id, None)
-                    if stream:
-                        await stream.send(message.message.root)
-                    else:
-                        await self._handle_incoming(
-                            RuntimeError(
-                                "Received response with an unknown "
-                                f"request ID: {message}"
+                                    # If there is a progress callback for this token,
+                                    # call it with the progress information
+                                    if progress_token in self._progress_callbacks:
+                                        callback = self._progress_callbacks[
+                                            progress_token
+                                        ]
+                                        await callback(
+                                            notification.root.params.progress,
+                                            notification.root.params.total,
+                                            notification.root.params.message,
+                                        )
+                                await self._received_notification(notification)
+                                await self._handle_incoming(notification)
+                        except Exception as e:
+                            # For other validation errors, log and continue
+                            logging.warning(
+                                f"Failed to validate notification: {e}. "
+                                f"Message was: {message.message.root}"
                             )
+                    else:  # Response or error
+                        stream = self._response_streams.pop(
+                            message.message.root.id, None
                         )
+                        if stream:
+                            await stream.send(message.message.root)
+                        else:
+                            await self._handle_incoming(
+                                RuntimeError(
+                                    "Received response with an unknown "
+                                    f"request ID: {message}"
+                                )
+                            )
 
             # after the read stream is closed, we need to send errors
             # to any pending requests
diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py
index ba586d4a8..df083f1e5 100644
--- a/tests/client/test_sampling_callback.py
+++ b/tests/client/test_sampling_callback.py
@@ -1,3 +1,4 @@
+import anyio
 import pytest
 
 from mcp.client.session import ClientSession
@@ -71,3 +72,105 @@ async def test_sampling_tool(message: str):
             result.content[0].text
             == "Error executing tool test_sampling: Sampling not supported"
         )
+
+
+@pytest.mark.anyio
+async def test_concurrent_sampling_callback():
+    """Test multiple concurrent sampling calls using time-sort verification."""
+    from mcp.server.fastmcp import FastMCP
+
+    server = FastMCP("test")
+
+    # Track completion order using time-sort approach
+    completion_order = []
+
+    async def sampling_callback(
+        context: RequestContext[ClientSession, None],
+        params: CreateMessageRequestParams,
+    ) -> CreateMessageResult:
+        # Extract delay from the message content (e.g., "delay_0.3")
+        assert isinstance(params.messages[0].content, TextContent)
+        message_text = params.messages[0].content.text
+        if message_text.startswith("delay_"):
+            delay = float(message_text.split("_")[1])
+            # Simulate different LLM response times
+            await anyio.sleep(delay)
+            completion_order.append(delay)
+            return CreateMessageResult(
+                role="assistant",
+                content=TextContent(type="text", text=f"Response after {delay}s"),
+                model="test-model",
+                stopReason="endTurn",
+            )
+
+        return CreateMessageResult(
+            role="assistant",
+            content=TextContent(type="text", text="Default response"),
+            model="test-model",
+            stopReason="endTurn",
+        )
+
+    @server.tool("concurrent_sampling_tool")
+    async def concurrent_sampling_tool():
+        """Tool that makes multiple concurrent sampling calls."""
+        # Use TaskGroup to make multiple concurrent sampling calls
+        # Using out-of-order durations: 0.6s, 0.2s, 0.4s
+        # If concurrent, should complete in order: 0.2s, 0.4s, 0.6s
+        async with anyio.create_task_group() as tg:
+            results = {}
+
+            async def make_sampling_call(call_id: str, delay: float):
+                result = await server.get_context().session.create_message(
+                    messages=[
+                        SamplingMessage(
+                            role="user",
+                            content=TextContent(type="text", text=f"delay_{delay}"),
+                        )
+                    ],
+                    max_tokens=100,
+                )
+                results[call_id] = result
+
+            # Start operations with out-of-order timing
+            tg.start_soon(make_sampling_call, "slow_call", 0.6)  # Should finish last
+            tg.start_soon(make_sampling_call, "fast_call", 0.2)  # Should finish first
+            tg.start_soon(
+                make_sampling_call, "medium_call", 0.4
+            )  # Should finish middle
+
+        # Combine results to show all completed
+        combined_response = " | ".join(
+            [
+                results["slow_call"].content.text,
+                results["fast_call"].content.text,
+                results["medium_call"].content.text,
+            ]
+        )
+
+        return combined_response
+
+    # Test concurrent sampling calls with time-sort verification
+    async with create_session(
+        server._mcp_server, sampling_callback=sampling_callback
+    ) as client_session:
+        # Make a request that triggers multiple concurrent sampling calls
+        result = await client_session.call_tool("concurrent_sampling_tool", {})
+
+        assert result.isError is False
+        assert isinstance(result.content[0], TextContent)
+
+        # Verify all sampling calls completed with expected responses
+        expected_result = (
+            "Response after 0.6s | Response after 0.2s | Response after 0.4s"
+        )
+        assert result.content[0].text == expected_result
+
+        # Key test: verify concurrent execution using time-sort
+        # Started in order: 0.6s, 0.2s, 0.4s
+        # Should complete in order: 0.2s, 0.4s, 0.6s (fastest first)
+        assert len(completion_order) == 3
+        assert completion_order == [
+            0.2,
+            0.4,
+            0.6,
+        ], f"Expected [0.2, 0.4, 0.6] but got {completion_order}"
diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py
index eb4e004ae..4361f951b 100644
--- a/tests/shared/test_session.py
+++ b/tests/shared/test_session.py
@@ -17,6 +17,7 @@
     ClientNotification,
     ClientRequest,
     EmptyResult,
+    TextContent,
 )
 
 
@@ -181,3 +182,76 @@ async def mock_server():
             await ev_closed.wait()
         with anyio.fail_after(1):
             await ev_response.wait()
+
+
+@pytest.mark.anyio
+async def test_async_request_handling_with_taskgroup():
+    """Test that a task-group-handled request completes and is cleaned up."""
+    # Track completion order
+    completion_order = []
+
+    def make_server() -> Server:
+        server = Server(name="AsyncTestServer")
+
+        @server.call_tool()
+        async def handle_call_tool(name: str, arguments: dict | None) -> list:
+            if name.startswith("timed_tool"):
+                # Extract wait time from tool name (e.g., "timed_tool_0.2")
+                wait_time = float(name.split("_")[-1])
+
+                # Wait for the specified time
+                await anyio.sleep(wait_time)
+
+                # Record completion
+                completion_order.append(wait_time)
+
+                return [TextContent(type="text", text=f"Waited {wait_time}s")]
+
+            raise ValueError(f"Unknown tool: {name}")
+
+        @server.list_tools()
+        async def handle_list_tools() -> list[types.Tool]:
+            return [
+                types.Tool(
+                    name="timed_tool_0.1",
+                    description="Tool that waits 0.1s",
+                    inputSchema={},
+                ),
+                types.Tool(
+                    name="timed_tool_0.2",
+                    description="Tool that waits 0.2s",
+                    inputSchema={},
+                ),
+                types.Tool(
+                    name="timed_tool_0.05",
+                    description="Tool that waits 0.05s",
+                    inputSchema={},
+                ),
+            ]
+
+        return server
+
+    async with create_connected_server_and_client_session(
+        make_server()
+    ) as client_session:
+        # Test basic async handling with a single request
+        result = await client_session.send_request(
+            ClientRequest(
+                types.CallToolRequest(
+                    method="tools/call",
+                    params=types.CallToolRequestParams(
+                        name="timed_tool_0.1", arguments={}
+                    ),
+                )
+            ),
+            types.CallToolResult,
+        )
+
+        # Verify the request completed successfully
+        assert isinstance(result.content[0], TextContent)
+        assert result.content[0].text == "Waited 0.1s"
+        assert len(completion_order) == 1
+        assert completion_order[0] == 0.1
+
+        # Verify no pending requests remain
+        assert len(client_session._in_flight) == 0
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index 5cf346e1a..5d9e0d97e 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -172,7 +172,9 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
                     related_request_id=ctx.request_id,  # need for stream association
                 )
 
-                await anyio.sleep(0.1)
+                # need to wait for long enough that the client can
+                # reliably stop the tool before this finishes
+                await anyio.sleep(0.3)
 
                 await ctx.session.send_log_message(
                     level="info",
@@ -181,6 +183,18 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
                     related_request_id=ctx.request_id,
                 )
 
+                # Adding another message just to make it even less
+                # likely that this tool will exit before the client
+                # can stop it
+                await anyio.sleep(0.3)
+
+                await ctx.session.send_log_message(
+                    level="info",
+                    data="Tool is done",
+                    logger="tool",
+                    related_request_id=ctx.request_id,
+                )
+
                 return [TextContent(type="text", text="Completed!")]
 
             elif name == "test_sampling_tool":
@@ -1099,6 +1113,11 @@ async def run_tool():
                     await anyio.sleep(0.1)
                 tg.cancel_scope.cancel()
 
+            # Make sure we only have one notification; otherwise the test is flaky
+            # More than one notification means the tool likely could have finished
+            # already and will not call the message handler again upon resumption
+            assert len(captured_notifications) == 1
+
             # Store pre notifications and clear the captured notifications
             # for the post-resumption check
             captured_notifications_pre = captured_notifications.copy()
@@ -1125,6 +1144,10 @@ async def run_tool():
             metadata = ClientMessageMetadata(
                 resumption_token=captured_resumption_token,
             )
+            # We need to wait for the tool to send another message so this doesn't
+            # deadlock. Fixing is out of scope for this PR. More details in
+            # https://github.com/modelcontextprotocol/python-sdk/issues/860
+            await anyio.sleep(0.2)
             result = await session.send_request(
                 types.ClientRequest(
                     types.CallToolRequest(
@@ -1144,7 +1167,7 @@ async def run_tool():
             assert "Completed" in result.content[0].text
 
             # We should have received the remaining notifications
-            assert len(captured_notifications) > 0
+            assert len(captured_notifications) == 2
 
             # Should not have the first notification
             # Check that "Tool started" notification isn't repeated when resuming