Skip to content

Close unclosed resources in the whole project #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 13, 2025
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ packages = ["src/mcp"]
include = ["src/mcp", "tests"]
venvPath = "."
venv = ".venv"
strict = [
"src/mcp/server/fastmcp/tools/base.py",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More should be added here.

]

[tool.ruff.lint]
select = ["E", "F", "I"]
Expand All @@ -85,3 +88,10 @@ members = ["examples/servers/*"]

[tool.uv.sources]
mcp = { workspace = true }

# TODO(Marcelo): This should be enabled!!! There are a lot of resource warnings.
[tool.pytest.ini_options]
xfail_strict = true
filterwarnings = [
"error",
]
10 changes: 7 additions & 3 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ async def _default_list_roots_callback(
)


ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)


class ClientSession(
Expand Down Expand Up @@ -219,7 +221,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
)

async def call_tool(
self, name: str, arguments: dict | None = None
self, name: str, arguments: dict[str, Any] | None = None
) -> types.CallToolResult:
"""Send a tools/call request."""
return await self.send_request(
Expand Down Expand Up @@ -258,7 +260,9 @@ async def get_prompt(
)

async def complete(
self, ref: types.ResourceReference | types.PromptReference, argument: dict
self,
ref: types.ResourceReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""
return await self.send_request(
Expand Down
6 changes: 3 additions & 3 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
class Tool(BaseModel):
"""Internal tool registration info."""

fn: Callable = Field(exclude=True)
fn: Callable[..., Any] = Field(exclude=True)
name: str = Field(description="Name of the tool")
description: str = Field(description="Description of what the tool does")
parameters: dict = Field(description="JSON schema for tool parameters")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool"
" arguments"
Expand All @@ -34,7 +34,7 @@ class Tool(BaseModel):
@classmethod
def from_function(
cls,
fn: Callable,
fn: Callable[..., Any],
name: str | None = None,
description: str | None = None,
context_kwarg: str | None = None,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/server/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
)


def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
def func_metadata(
func: Callable[..., Any], skip_names: Sequence[str] = ()
) -> FuncMetadata:
"""Given a function, return metadata including a pydantic model representing its
signature.

Expand Down
20 changes: 16 additions & 4 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar

Expand Down Expand Up @@ -180,13 +181,20 @@ def __init__(
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}

self._exit_stack = AsyncExitStack()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
RequestResponder[ReceiveRequestT, SendResultT]
| ReceiveNotificationT
| Exception
]()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_reader.aclose()
)
self._exit_stack.push_async_callback(
lambda: self._incoming_message_stream_writer.aclose()
)

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
Expand All @@ -195,6 +203,7 @@ async def __aenter__(self) -> Self:
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
Expand Down Expand Up @@ -222,6 +231,9 @@ async def send_request(
](1)
self._response_streams[request_id] = response_stream

self._exit_stack.push_async_callback(lambda: response_stream.aclose())
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())

jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
Expand Down Expand Up @@ -250,11 +262,11 @@ async def send_request(
),
)
)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)
if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
return result_type.model_validate(response_or_error.result)

async def send_notification(self, notification: SendNotificationT) -> None:
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ async def listen_session():
async with (
ClientSession(server_to_client_receive, client_to_server_send) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
tg.start_soon(listen_session)
Expand Down
8 changes: 7 additions & 1 deletion tests/issues/test_192_request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ async def run_server():
)

# Start server task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
client_writer,
client_reader,
server_writer,
server_reader,
):
tg.start_soon(run_server)

# Send initialize request
Expand Down
18 changes: 15 additions & 3 deletions tests/server/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_lowlevel_server_lifespan():
"""Test that lifespan works in low-level server."""

@asynccontextmanager
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]:
"""Test lifespan context that tracks startup/shutdown."""
context = {"started": False, "shutdown": False}
try:
Expand All @@ -50,7 +50,13 @@ async def check_lifespan(name: str, arguments: dict) -> list:
return [{"type": "text", "text": "true"}]

# Run server in background task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
send_stream1,
receive_stream1,
send_stream2,
receive_stream2,
):

async def run_server():
await server.run(
Expand Down Expand Up @@ -147,7 +153,13 @@ def check_lifespan(ctx: Context) -> bool:
return True

# Run server in background task
async with anyio.create_task_group() as tg:
async with (
anyio.create_task_group() as tg,
send_stream1,
receive_stream1,
send_stream2,
receive_stream2,
):

async def run_server():
await server._mcp_server.run(
Expand Down
Loading