diff --git a/pyproject.toml b/pyproject.toml index 956d9c8c..157263de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ packages = ["src/mcp"] include = ["src/mcp", "tests"] venvPath = "." venv = ".venv" +strict = [ + "src/mcp/server/fastmcp/tools/base.py", +] [tool.ruff.lint] select = ["E", "F", "I"] @@ -85,3 +88,13 @@ members = ["examples/servers/*"] [tool.uv.sources] mcp = { workspace = true } + +[tool.pytest.ini_options] +xfail_strict = true +filterwarnings = [ + "error", + # This should be fixed on Uvicorn's side. + "ignore::DeprecationWarning:websockets", + "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" +] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c1cc5b5f..cde3103b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -43,7 +43,9 @@ async def _default_list_roots_callback( ) -ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData) +ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( + types.ClientResult | types.ErrorData +) class ClientSession( @@ -219,7 +221,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) async def call_tool( - self, name: str, arguments: dict | None = None + self, name: str, arguments: dict[str, Any] | None = None ) -> types.CallToolResult: """Send a tools/call request.""" return await self.send_request( @@ -258,7 +260,9 @@ async def get_prompt( ) async def complete( - self, ref: types.ResourceReference | types.PromptReference, argument: dict + self, + ref: types.ResourceReference | types.PromptReference, + argument: dict[str, str], ) -> types.CompleteResult: """Send a completion/complete request.""" return await self.send_request( diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index da5d9348..bf68dc02 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -18,10 +18,10 @@ class Tool(BaseModel): """Internal tool registration info.""" - fn: Callable = Field(exclude=True) + fn: Callable[..., Any] = Field(exclude=True) name: str = Field(description="Name of the tool") description: str = Field(description="Description of what the tool does") - parameters: dict = Field(description="JSON schema for tool parameters") + parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") fn_metadata: FuncMetadata = Field( description="Metadata about the function including a pydantic model for tool" " arguments" @@ -34,7 +34,7 @@ class Tool(BaseModel): @classmethod def from_function( cls, - fn: Callable, + fn: Callable[..., Any], name: str | None = None, description: str | None = None, context_kwarg: str | None = None, diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index cf93049e..7bcc9baf 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -102,7 +102,9 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: ) -def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata: +def func_metadata( + func: Callable[..., Any], skip_names: Sequence[str] = () +) -> FuncMetadata: """Given a function, return metadata including a pydantic model representing its signature. diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index d0dcaee8..31f88824 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,4 +1,5 @@ import logging +from contextlib import AsyncExitStack from datetime import timedelta from typing import Any, Callable, Generic, TypeVar @@ -180,6 +181,7 @@ def __init__( self._read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._exit_stack = AsyncExitStack() self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ RequestResponder[ReceiveRequestT, SendResultT] @@ -187,6 +189,12 @@ def __init__( | Exception ]() ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_reader.aclose() + ) + self._exit_stack.push_async_callback( + lambda: self._incoming_message_stream_writer.aclose() + ) async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -195,6 +203,7 @@ async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._exit_stack.aclose() # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks # in the task group. @@ -222,6 +231,9 @@ async def send_request( ](1) self._response_streams[request_id] = response_stream + self._exit_stack.push_async_callback(lambda: response_stream.aclose()) + self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose()) + jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 90de898c..7d579cda 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -83,6 +83,10 @@ async def listen_session(): async with ( ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, ): tg.start_soon(mock_server) tg.start_soon(listen_session) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 628f00f9..00e18789 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -43,7 +43,13 @@ async def run_server(): ) # Start server task - async with anyio.create_task_group() as tg: + async with ( + anyio.create_task_group() as tg, + client_writer, + client_reader, + server_writer, + server_reader, + ): tg.start_soon(run_server) # Send initialize request diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 14afb6b0..37a52969 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -25,7 +25,7 @@ async def test_lowlevel_server_lifespan(): """Test that lifespan works in low-level server.""" @asynccontextmanager - async def test_lifespan(server: Server) -> AsyncIterator[dict]: + async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: """Test lifespan context that tracks startup/shutdown.""" context = {"started": False, "shutdown": False} try: @@ -50,7 +50,13 @@ async def check_lifespan(name: str, arguments: dict) -> list: return [{"type": "text", "text": "true"}] # Run server in background task - async with anyio.create_task_group() as tg: + async with ( + anyio.create_task_group() as tg, + send_stream1, + receive_stream1, + send_stream2, + receive_stream2, + ): async def run_server(): await server.run( @@ -147,7 +153,13 @@ def check_lifespan(ctx: Context) -> bool: return True # Run server in background task - async with anyio.create_task_group() as tg: + async with ( + anyio.create_task_group() as tg, + send_stream1, + receive_stream1, + send_stream2, + receive_stream2, + ): async def run_server(): await server._mcp_server.run(