Skip to content

Feature: Async handling of sampling calls #840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 75 additions & 62 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,76 +350,89 @@ async def _receive_loop(self) -> None:
self._read_stream,
self._write_stream,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
async with anyio.create_task_group() as tg:
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(
r.request_id, None
),
message_metadata=message.metadata,
)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

self._in_flight[responder.request_id] = responder
await self._received_request(responder)
async def _handle_received_request() -> None:
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
self._in_flight[responder.request_id] = responder
tg.start_soon(_handle_received_request)

elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = (
self._receive_notification_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = (
notification.root.params.progressToken
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[
progress_token
]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(
message.message.root.id, None
)
if stream:
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError(
"Received response with an unknown "
f"request ID: {message}"
)
)

# after the read stream is closed, we need to send errors
# to any pending requests
Expand Down
103 changes: 103 additions & 0 deletions tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import anyio
import pytest

from mcp.client.session import ClientSession
Expand Down Expand Up @@ -71,3 +72,105 @@ async def test_sampling_tool(message: str):
result.content[0].text
== "Error executing tool test_sampling: Sampling not supported"
)


@pytest.mark.anyio
async def test_concurrent_sampling_callback():
    """Test multiple concurrent sampling calls using time-sort verification."""
    from mcp.server.fastmcp import FastMCP

    server = FastMCP("test")

    # Completion order of the simulated LLM calls; used for the
    # time-sort check at the end of the test.
    completion_order = []

    async def sampling_callback(
        context: RequestContext[ClientSession, None],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        # The message text encodes the simulated latency (e.g. "delay_0.3").
        assert isinstance(params.messages[0].content, TextContent)
        text = params.messages[0].content.text
        if not text.startswith("delay_"):
            return CreateMessageResult(
                role="assistant",
                content=TextContent(type="text", text="Default response"),
                model="test-model",
                stopReason="endTurn",
            )

        wait = float(text.split("_")[1])
        # Simulate an LLM taking `wait` seconds to produce its response.
        await anyio.sleep(wait)
        completion_order.append(wait)
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=f"Response after {wait}s"),
            model="test-model",
            stopReason="endTurn",
        )

    @server.tool("concurrent_sampling_tool")
    async def concurrent_sampling_tool():
        """Tool that makes multiple concurrent sampling calls."""
        # Sampling responses keyed by call id, filled in as each call returns.
        results = {}

        async def sample(call_id: str, delay: float):
            response = await server.get_context().session.create_message(
                messages=[
                    SamplingMessage(
                        role="user",
                        content=TextContent(type="text", text=f"delay_{delay}"),
                    )
                ],
                max_tokens=100,
            )
            results[call_id] = response

        # Start the calls with deliberately out-of-order durations
        # (0.6s, 0.2s, 0.4s). If they run concurrently they must
        # complete fastest-first: 0.2s, 0.4s, 0.6s.
        async with anyio.create_task_group() as tg:
            tg.start_soon(sample, "slow_call", 0.6)  # should finish last
            tg.start_soon(sample, "fast_call", 0.2)  # should finish first
            tg.start_soon(sample, "medium_call", 0.4)  # should finish middle

        # Combine the responses to show that every call completed.
        return " | ".join(
            [
                results["slow_call"].content.text,
                results["fast_call"].content.text,
                results["medium_call"].content.text,
            ]
        )

    # Drive the tool through a client session wired to the callback above.
    async with create_session(
        server._mcp_server, sampling_callback=sampling_callback
    ) as client_session:
        # One tool call fans out into three concurrent sampling requests.
        result = await client_session.call_tool("concurrent_sampling_tool", {})

        assert result.isError is False
        assert isinstance(result.content[0], TextContent)

        # Every sampling call completed and is reported in start order.
        assert result.content[0].text == (
            "Response after 0.6s | Response after 0.2s | Response after 0.4s"
        )

        # Key check: completion order is sorted by duration, not start
        # order — only possible if the calls actually ran concurrently.
        assert len(completion_order) == 3
        assert completion_order == [
            0.2,
            0.4,
            0.6,
        ], f"Expected [0.2, 0.4, 0.6] but got {completion_order}"
76 changes: 76 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ClientNotification,
ClientRequest,
EmptyResult,
TextContent,
)


Expand Down Expand Up @@ -181,3 +182,78 @@ async def mock_server():
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()


@pytest.mark.anyio
async def test_async_request_handling_with_taskgroup():
    """Test that a request is handled via the session's task group.

    Issues a single timed tool call end-to-end, then verifies the call
    completed with the expected payload and that the responder was removed
    from the client's in-flight table once the response arrived.
    """
    # Records the wait time of each tool call as it completes.
    completion_order = []

    def make_server() -> Server:
        server = Server(name="AsyncTestServer")

        @server.call_tool()
        async def handle_call_tool(name: str, arguments: dict | None) -> list:
            # NOTE: no `nonlocal` needed — completion_order is only
            # mutated (append), never rebound.
            if name.startswith("timed_tool"):
                # Extract the wait time from the tool name
                # (e.g. "timed_tool_0.2" -> 0.2).
                wait_time = float(name.split("_")[-1])

                # Simulate a slow handler.
                await anyio.sleep(wait_time)

                # Record completion for ordering assertions.
                completion_order.append(wait_time)

                return [TextContent(type="text", text=f"Waited {wait_time}s")]

            raise ValueError(f"Unknown tool: {name}")

        @server.list_tools()
        async def handle_list_tools() -> list[types.Tool]:
            # Three timed tools with different latencies; only the 0.1s
            # one is exercised below.
            return [
                types.Tool(
                    name="timed_tool_0.1",
                    description="Tool that waits 0.1s",
                    inputSchema={},
                ),
                types.Tool(
                    name="timed_tool_0.2",
                    description="Tool that waits 0.2s",
                    inputSchema={},
                ),
                types.Tool(
                    name="timed_tool_0.05",
                    description="Tool that waits 0.05s",
                    inputSchema={},
                ),
            ]

        return server

    async with create_connected_server_and_client_session(
        make_server()
    ) as client_session:
        # Issue a single timed tool call through the session.
        result = await client_session.send_request(
            ClientRequest(
                types.CallToolRequest(
                    method="tools/call",
                    params=types.CallToolRequestParams(
                        name="timed_tool_0.1", arguments={}
                    ),
                )
            ),
            types.CallToolResult,
        )

        # The request completed successfully with the expected payload.
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "Waited 0.1s"
        assert len(completion_order) == 1
        assert completion_order[0] == 0.1

        # The responder was cleaned up: no pending in-flight requests remain.
        assert len(client_session._in_flight) == 0
Loading
Loading