|
7 | 7 | from mcp.client.session import ClientSession
|
8 | 8 | from mcp.server.lowlevel.server import Server
|
9 | 9 | from mcp.shared.exceptions import McpError
|
10 |
| -from mcp.shared.memory import create_connected_server_and_client_session |
| 10 | +from mcp.shared.memory import ( |
| 11 | + create_client_server_memory_streams, |
| 12 | + create_connected_server_and_client_session, |
| 13 | +) |
11 | 14 | from mcp.types import (
|
12 | 15 | CancelledNotification,
|
13 | 16 | CancelledNotificationParams,
|
@@ -124,3 +127,55 @@ async def make_request(client_session):
|
124 | 127 | # Give cancellation time to process
|
125 | 128 | with anyio.fail_after(1):
|
126 | 129 | await ev_cancelled.wait()
|
| 130 | + |
| 131 | + |
| 132 | +@pytest.mark.anyio |
| 133 | +async def test_connection_closed(): |
| 134 | + """Test that pending requests are cancelled when the connection is closed remotely.""" |
| 135 | + |
| 136 | + ev_closed = anyio.Event() |
| 137 | + ev_response = anyio.Event() |
| 138 | + |
| 139 | + async with create_client_server_memory_streams() as ( |
| 140 | + client_streams, |
| 141 | + server_streams, |
| 142 | + ): |
| 143 | + client_read, client_write = client_streams |
| 144 | + server_read, server_write = server_streams |
| 145 | + |
| 146 | + async def make_request(client_session): |
| 147 | + """Send a request in a separate task""" |
| 148 | + nonlocal ev_response |
| 149 | + try: |
| 150 | + # any request will do |
| 151 | + await client_session.initialize() |
| 152 | + pytest.fail("Request should have errored") |
| 153 | + except McpError as e: |
| 154 | + # Expected - request errored |
| 155 | + assert "Connection closed" in str(e) |
| 156 | + ev_response.set() |
| 157 | + |
| 158 | + async def mock_server(): |
| 159 | + """Wait for a request, then close the connection""" |
| 160 | + nonlocal ev_closed |
| 161 | + # Wait for a request |
| 162 | + await server_read.receive() |
| 163 | + # Close the connection, as if the server exited |
| 164 | + server_write.close() |
| 165 | + server_read.close() |
| 166 | + ev_closed.set() |
| 167 | + |
| 168 | + async with ( |
| 169 | + anyio.create_task_group() as tg, |
| 170 | + ClientSession( |
| 171 | + read_stream=client_read, |
| 172 | + write_stream=client_write, |
| 173 | + ) as client_session, |
| 174 | + ): |
| 175 | + tg.start_soon(make_request, client_session) |
| 176 | + tg.start_soon(mock_server) |
| 177 | + |
| 178 | + with anyio.fail_after(1): |
| 179 | + await ev_closed.wait() |
| 180 | + with anyio.fail_after(1): |
| 181 | + await ev_response.wait() |
0 commit comments