diff --git a/README.md b/README.md index 84359984..60b5a726 100644 --- a/README.md +++ b/README.md @@ -16,33 +16,36 @@ ## Table of Contents -- [Overview](#overview) -- [Installation](#installation) -- [Quickstart](#quickstart) -- [What is MCP?](#what-is-mcp) -- [Core Concepts](#core-concepts) - - [Server](#server) - - [Resources](#resources) - - [Tools](#tools) - - [Prompts](#prompts) - - [Images](#images) - - [Context](#context) -- [Running Your Server](#running-your-server) - - [Development Mode](#development-mode) - - [Claude Desktop Integration](#claude-desktop-integration) - - [Direct Execution](#direct-execution) - - [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server) -- [Examples](#examples) - - [Echo Server](#echo-server) - - [SQLite Explorer](#sqlite-explorer) -- [Advanced Usage](#advanced-usage) - - [Low-Level Server](#low-level-server) - - [Writing MCP Clients](#writing-mcp-clients) - - [MCP Primitives](#mcp-primitives) - - [Server Capabilities](#server-capabilities) -- [Documentation](#documentation) -- [Contributing](#contributing) -- [License](#license) +- [MCP Python SDK](#mcp-python-sdk) + - [Overview](#overview) + - [Installation](#installation) + - [Adding MCP to your python project](#adding-mcp-to-your-python-project) + - [Running the standalone MCP development tools](#running-the-standalone-mcp-development-tools) + - [Quickstart](#quickstart) + - [What is MCP?](#what-is-mcp) + - [Core Concepts](#core-concepts) + - [Server](#server) + - [Resources](#resources) + - [Tools](#tools) + - [Prompts](#prompts) + - [Images](#images) + - [Context](#context) + - [Running Your Server](#running-your-server) + - [Development Mode](#development-mode) + - [Claude Desktop Integration](#claude-desktop-integration) + - [Direct Execution](#direct-execution) + - [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server) + - [Examples](#examples) + - [Echo Server](#echo-server) + - [SQLite Explorer](#sqlite-explorer) + - [Advanced Usage](#advanced-usage) + - [Low-Level Server](#low-level-server) + - [Writing MCP Clients](#writing-mcp-clients) + - [MCP Primitives](#mcp-primitives) + - [Server Capabilities](#server-capabilities) + - [Documentation](#documentation) + - [Contributing](#contributing) + - [License](#license) [pypi-badge]: https://img.shields.io/pypi/v/mcp.svg [pypi-url]: https://pypi.org/project/mcp/ @@ -143,8 +146,8 @@ The FastMCP server is your core interface to the MCP protocol. It handles connec ```python # Add lifespan support for startup/shutdown with strong typing from contextlib import asynccontextmanager +from collections.abc import AsyncIterator from dataclasses import dataclass -from typing import AsyncIterator from fake_database import Database # Replace with your actual DB type @@ -442,7 +445,7 @@ For more control, you can use the low-level server implementation directly. This ```python from contextlib import asynccontextmanager -from typing import AsyncIterator +from collections.abc import AsyncIterator from fake_database import Database # Replace with your actual DB type diff --git a/pyproject.toml b/pyproject.toml index 046e90a4..e400ad7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,13 +76,11 @@ packages = ["src/mcp"] include = ["src/mcp", "tests"] venvPath = "." venv = ".venv" -strict = [ - "src/mcp/server/fastmcp/tools/base.py", - "src/mcp/client/*.py" -] +strict = ["src/mcp/**/*.py"] +exclude = ["src/mcp/types.py"] [tool.ruff.lint] -select = ["E", "F", "I"] +select = ["E", "F", "I", "UP"] ignore = [] [tool.ruff] diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index fe3f3380..5a0ce0ab 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -4,6 +4,7 @@ import os import sys from pathlib import Path +from typing import Any from mcp.server.fastmcp.utilities.logging import get_logger @@ -116,10 +117,7 @@ def update_claude_config( # Add fastmcp run command args.extend(["mcp", "run", file_spec]) - server_config = { - "command": "uv", - "args": args, - } + server_config: dict[str, Any] = {"command": "uv", "args": args} # Add environment variables if specified if env_vars: diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 9cf32296..2c2ed38b 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,7 +1,7 @@ import json import logging +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index 0df3d2fd..71c48724 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -2,8 +2,8 @@ import inspect import json -from collections.abc import Callable -from typing import Any, Awaitable, Literal, Sequence +from collections.abc import Awaitable, Callable, Sequence +from typing import Any, Literal import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call @@ -19,7 +19,7 @@ class Message(BaseModel): role: Literal["user", "assistant"] content: CONTENT_TYPES - def __init__(self, content: str | CONTENT_TYPES, **kwargs): + def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): if isinstance(content, str): content = TextContent(type="text", text=content) super().__init__(content=content, **kwargs) @@ -30,7 +30,7 @@ class UserMessage(Message): role: Literal["user", "assistant"] = "user" - def __init__(self, content: str | CONTENT_TYPES, **kwargs): + def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): super().__init__(content=content, **kwargs) @@ -39,11 +39,13 @@ class AssistantMessage(Message): role: Literal["user", "assistant"] = "assistant" - def __init__(self, content: str | CONTENT_TYPES, **kwargs): + def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): super().__init__(content=content, **kwargs) -message_validator = TypeAdapter(UserMessage | AssistantMessage) +message_validator = TypeAdapter[UserMessage | AssistantMessage]( + UserMessage | AssistantMessage +) SyncPromptResult = ( str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] @@ -73,12 +75,12 @@ class Prompt(BaseModel): arguments: list[PromptArgument] | None = Field( None, description="Arguments that can be passed to the prompt" ) - fn: Callable = Field(exclude=True) + fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) @classmethod def from_function( cls, - fn: Callable[..., PromptResult], + fn: Callable[..., PromptResult | Awaitable[PromptResult]], name: str | None = None, description: str | None = None, ) -> "Prompt": @@ -99,7 +101,7 @@ def from_function( parameters = TypeAdapter(fn).json_schema() # Convert parameters to PromptArguments - arguments = [] + arguments: list[PromptArgument] = [] if "properties" in parameters: for param_name, param in parameters["properties"].items(): required = param_name in parameters.get("required", []) @@ -138,25 +140,23 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message] result = await result # Validate messages - if not isinstance(result, (list, tuple)): + if not isinstance(result, list | tuple): result = [result] # Convert result to messages - messages = [] - for msg in result: + messages: list[Message] = [] + for msg in result: # type: ignore[reportUnknownVariableType] try: if isinstance(msg, Message): messages.append(msg) elif isinstance(msg, dict): - msg = message_validator.validate_python(msg) - messages.append(msg) + messages.append(message_validator.validate_python(msg)) elif isinstance(msg, str): - messages.append( - UserMessage(content=TextContent(type="text", text=msg)) - ) + content = TextContent(type="text", text=msg) + messages.append(UserMessage(content=content)) else: - msg = json.dumps(pydantic_core.to_jsonable_python(msg)) - messages.append(Message(role="user", content=msg)) + content = json.dumps(pydantic_core.to_jsonable_python(msg)) + messages.append(Message(role="user", content=content)) except Exception: raise ValueError( f"Could not convert prompt result to message: {msg}" diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index ef4af84c..d27e6ac1 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -1,6 +1,7 @@ """Resource manager functionality.""" -from typing import Callable +from collections.abc import Callable +from typing import Any from pydantic import AnyUrl @@ -47,7 +48,7 @@ def add_resource(self, resource: Resource) -> Resource: def add_template( self, - fn: Callable, + fn: Callable[..., Any], uri_template: str, name: str | None = None, description: str | None = None, diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index 40afaf80..a30b1825 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -1,8 +1,11 @@ """Resource template functionality.""" +from __future__ import annotations + import inspect import re -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from pydantic import BaseModel, Field, TypeAdapter, validate_call @@ -20,18 +23,20 @@ class ResourceTemplate(BaseModel): mime_type: str = Field( default="text/plain", description="MIME type of the resource content" ) - fn: Callable = Field(exclude=True) - parameters: dict = Field(description="JSON schema for function parameters") + fn: Callable[..., Any] = Field(exclude=True) + parameters: dict[str, Any] = Field( + description="JSON schema for function parameters" + ) @classmethod def from_function( cls, - fn: Callable, + fn: Callable[..., Any], uri_template: str, name: str | None = None, description: str | None = None, mime_type: str | None = None, - ) -> "ResourceTemplate": + ) -> ResourceTemplate: """Create a template from a function.""" func_name = name or fn.__name__ if func_name == "": diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 9affd9be..2f807fae 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -5,13 +5,13 @@ import inspect import json import re -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Callable, Iterable, Sequence from contextlib import ( AbstractAsyncContextManager, asynccontextmanager, ) from itertools import chain -from typing import Any, Callable, Generic, Literal, Sequence +from typing import Any, Generic, Literal import anyio import pydantic_core @@ -20,6 +20,7 @@ from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette +from starlette.requests import Request from starlette.routing import Mount, Route from mcp.server.fastmcp.exceptions import ResourceError @@ -88,13 +89,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) lifespan: ( - Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None + Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None ) = Field(None, description="Lifespan context manager") def lifespan_wrapper( app: FastMCP, - lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], + lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], ) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: @asynccontextmanager async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: @@ -179,7 +180,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> "Context[ServerSession, object]": + def get_context(self) -> Context[ServerSession, object]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. @@ -478,9 +479,11 @@ def sse_app(self) -> Starlette: """Return an instance of the SSE server app.""" sse = SseServerTransport("/messages/") - async def handle_sse(request): + async def handle_sse(request: Request) -> None: async with sse.connect_sse( - request.scope, request.receive, request._send + request.scope, + request.receive, + request._send, # type: ignore[reportPrivateUsage] ) as streams: await self._mcp_server.run( streams[0], @@ -535,14 +538,14 @@ def _convert_to_content( if result is None: return [] - if isinstance(result, (TextContent, ImageContent, EmbeddedResource)): + if isinstance(result, TextContent | ImageContent | EmbeddedResource): return [result] if isinstance(result, Image): return [result.to_image_content()] - if isinstance(result, (list, tuple)): - return list(chain.from_iterable(_convert_to_content(item) for item in result)) + if isinstance(result, list | tuple): + return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType] if not isinstance(result, str): try: diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index bf68dc02..e137e845 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -1,11 +1,11 @@ from __future__ import annotations as _annotations import inspect -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field -import mcp.server.fastmcp from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata @@ -38,8 +38,10 @@ def from_function( name: str | None = None, description: str | None = None, context_kwarg: str | None = None, - ) -> "Tool": + ) -> Tool: """Create a Tool from a function.""" + from mcp.server.fastmcp import Context + func_name = name or fn.__name__ if func_name == "": @@ -48,11 +50,10 @@ def from_function( func_doc = description or fn.__doc__ or "" is_async = inspect.iscoroutinefunction(fn) - # Find context parameter if it exists if context_kwarg is None: sig = inspect.signature(fn) for param_name, param in sig.parameters.items(): - if param.annotation is mcp.server.fastmcp.Context: + if param.annotation is Context: context_kwarg = param_name break diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 9a8bba8d..4d6ac268 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -32,7 +32,7 @@ def list_tools(self) -> list[Tool]: def add_tool( self, - fn: Callable, + fn: Callable[..., Any], name: str | None = None, description: str | None = None, ) -> Tool: diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 7bcc9baf..25832620 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -80,7 +80,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: dicts (JSON objects) as JSON strings, which can be pre-parsed here. """ new_data = data.copy() # Shallow copy - for field_name, field_info in self.arg_model.model_fields.items(): + for field_name, _field_info in self.arg_model.model_fields.items(): if field_name not in data.keys(): continue if isinstance(data[field_name], str): @@ -177,7 +177,9 @@ def func_metadata( def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: - def try_eval_type(value, globalns, localns): + def try_eval_type( + value: Any, globalns: dict[str, Any], localns: dict[str, Any] + ) -> tuple[Any, bool]: try: return eval_type_backport(value, globalns, localns), True except NameError: diff --git a/src/mcp/server/fastmcp/utilities/logging.py b/src/mcp/server/fastmcp/utilities/logging.py index df9da433..091d57e6 100644 --- a/src/mcp/server/fastmcp/utilities/logging.py +++ b/src/mcp/server/fastmcp/utilities/logging.py @@ -24,7 +24,7 @@ def configure_logging( Args: level: the log level to use """ - handlers = [] + handlers: list[logging.Handler] = [] try: from rich.console import Console from rich.logging import RichHandler diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 817d1918..e14f73e1 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -69,9 +69,9 @@ async def main(): import contextvars import logging import warnings -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from typing import Any, AsyncIterator, Generic, TypeVar +from typing import Any, Generic, TypeVar import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -155,9 +155,7 @@ def pkg_version(package: str) -> str: try: from importlib.metadata import version - v = version(package) - if v is not None: - return v + return version(package) except Exception: pass @@ -320,7 +318,6 @@ def create_content(data: str | bytes, mime_type: str | None): contents_list = [ create_content(content_item.content, content_item.mime_type) for content_item in contents - if isinstance(content_item, ReadResourceContents) ] return types.ServerResult( types.ReadResourceResult( @@ -511,7 +508,8 @@ async def _handle_message( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - match message: + # TODO(Marcelo): We should be checking if message is Exception here. + match message: # type: ignore[reportMatchNotExhaustive] case ( RequestResponder(request=types.ClientRequest(root=req)) as responder ): @@ -527,7 +525,7 @@ async def _handle_message( async def _handle_request( self, - message: RequestResponder, + message: RequestResponder[types.ClientRequest, types.ServerResult], req: Any, session: ServerSession, lifespan_context: LifespanResultT, diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 495f0c1e..938d4a30 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -2,9 +2,10 @@ In-memory transports """ +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import timedelta -from typing import AsyncGenerator +from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -52,7 +53,7 @@ async def create_client_server_memory_streams() -> ( @asynccontextmanager async def create_connected_server_and_client_session( - server: Server, + server: Server[Any], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 19ea5ede..db502cf0 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -1,10 +1,19 @@ +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass, field +from typing import Generic from pydantic import BaseModel from mcp.shared.context import RequestContext -from mcp.shared.session import BaseSession +from mcp.shared.session import ( + BaseSession, + ReceiveNotificationT, + ReceiveRequestT, + SendNotificationT, + SendRequestT, + SendResultT, +) from mcp.types import ProgressToken @@ -14,8 +23,22 @@ class Progress(BaseModel): @dataclass -class ProgressContext: - session: BaseSession +class ProgressContext( + Generic[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ] +): + session: BaseSession[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ] progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) @@ -29,7 +52,27 @@ async def progress(self, amount: float) -> None: @contextmanager -def progress(ctx: RequestContext, total: float | None = None): +def progress( + ctx: RequestContext[ + BaseSession[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ] + ], + total: float | None = None, +) -> Generator[ + ProgressContext[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ], + None, +]: if ctx.meta is None or ctx.meta.progressToken is None: raise ValueError("No progress token provided") diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 31f88824..31c04df3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,7 +1,9 @@ import logging +from collections.abc import Callable from contextlib import AsyncExitStack from datetime import timedelta -from typing import Any, Callable, Generic, TypeVar +from types import TracebackType +from typing import Any, Generic, TypeVar import anyio import anyio.lowlevel @@ -86,7 +88,12 @@ def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": self._cancel_scope.__enter__() return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Exit the context manager, performing cleanup and notifying completion.""" try: if self._completed: @@ -112,7 +119,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None: if not self.cancelled: self._completed = True - await self._session._send_response( + await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, response=response ) @@ -126,7 +133,7 @@ async def cancel(self) -> None: self._cancel_scope.cancel() self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation - await self._session._send_response( + await self._session._send_response( # type: ignore[reportPrivateUsage] request_id=self.request_id, response=ErrorData(code=0, message="Request cancelled", data=None), ) @@ -137,7 +144,7 @@ def in_flight(self) -> bool: @property def cancelled(self) -> bool: - return self._cancel_scope is not None and self._cancel_scope.cancel_called + return self._cancel_scope.cancel_called class BaseSession( @@ -202,7 +209,12 @@ async def __aenter__(self) -> Self: self._task_group.start_soon(self._receive_loop) return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: await self._exit_stack.aclose() # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks @@ -324,7 +336,7 @@ async def _receive_loop(self) -> None: self._in_flight[responder.request_id] = responder await self._received_request(responder) - if not responder._completed: + if not responder._completed: # type: ignore[reportPrivateUsage] await self._incoming_message_stream_writer.send(responder) elif isinstance(message.root, JSONRPCNotification): diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd3..f043fb10 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,7 +1,7 @@ +from collections.abc import Callable from typing import ( Annotated, Any, - Callable, Generic, Literal, TypeAlias, @@ -89,6 +89,7 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): """Base class for JSON-RPC notifications.""" method: MethodT + params: NotificationParamsT model_config = ConfigDict(extra="allow") @@ -1010,7 +1011,9 @@ class CancelledNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class CancelledNotification(Notification): +class CancelledNotification( + Notification[CancelledNotificationParams, Literal["notifications/cancelled"]] +): """ This notification can be sent by either side to indicate that it is cancelling a previously-issued request. diff --git a/tests/client/test_config.py b/tests/client/test_config.py index b8371e4b..97030e06 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -1,5 +1,6 @@ import json import subprocess +from pathlib import Path from unittest.mock import patch import pytest @@ -8,7 +9,7 @@ @pytest.fixture -def temp_config_dir(tmp_path): +def temp_config_dir(tmp_path: Path): """Create a temporary Claude config directory.""" config_dir = tmp_path / "Claude" config_dir.mkdir() @@ -16,23 +17,20 @@ def temp_config_dir(tmp_path): @pytest.fixture -def mock_config_path(temp_config_dir): +def mock_config_path(temp_config_dir: Path): """Mock get_claude_config_path to return our temporary directory.""" with patch("mcp.cli.claude.get_claude_config_path", return_value=temp_config_dir): yield temp_config_dir -def test_command_execution(mock_config_path): +def test_command_execution(mock_config_path: Path): """Test that the generated command can actually be executed.""" # Setup server_name = "test_server" file_spec = "test_server.py:app" # Update config - success = update_claude_config( - file_spec=file_spec, - server_name=server_name, - ) + success = update_claude_config(file_spec=file_spec, server_name=server_name) assert success # Read the generated config diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 384e7676..f5b59821 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -7,11 +7,7 @@ from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, ) -from mcp.types import ( - ListRootsResult, - Root, - TextContent, -) +from mcp.types import ListRootsResult, Root, TextContent @pytest.mark.anyio @@ -39,7 +35,7 @@ async def list_roots_callback( return callback_return @server.tool("test_list_roots") - async def test_list_roots(context: Context, message: str): + async def test_list_roots(context: Context, message: str): # type: ignore[reportUnknownMemberType] roots = await context.session.list_roots() assert roots == callback_return return True diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index ead4f092..74f4b487 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Literal import anyio import pytest @@ -14,7 +14,7 @@ class LoggingCollector: def __init__(self): - self.log_messages: List[LoggingMessageNotificationParams] = [] + self.log_messages: list[LoggingMessageNotificationParams] = [] async def __call__(self, params: LoggingMessageNotificationParams) -> None: self.log_messages.append(params) diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 8609c209..0aac6608 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -1,8 +1,8 @@ """Test to reproduce issue #88: Random error thrown on response.""" +from collections.abc import Sequence from datetime import timedelta from pathlib import Path -from typing import Sequence import anyio import pytest diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 5d375ccc..e76e59c5 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,6 +1,6 @@ import base64 from pathlib import Path -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import pytest from pydantic import AnyUrl @@ -114,7 +114,7 @@ def image_tool_fn(path: str) -> Image: return Image(path) -def mixed_content_tool_fn() -> list[Union[TextContent, ImageContent]]: +def mixed_content_tool_fn() -> list[TextContent | ImageContent]: return [ TextContent(type="text", text="Hello"), ImageContent(type="image", data="abc", mimeType="image/png"), diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 4adfc47b..d2067583 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -1,6 +1,5 @@ import json import logging -from typing import Optional import pytest from pydantic import BaseModel @@ -296,7 +295,7 @@ async def test_context_optional(self): """Test that context is optional when calling tools.""" from mcp.server.fastmcp import Context - def tool_with_context(x: int, ctx: Optional[Context] = None) -> str: + def tool_with_context(x: int, ctx: Context | None = None) -> str: return str(x) manager = ToolManager() diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 37a52969..309a44b8 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -1,7 +1,7 @@ """Tests for lifespan functionality in both low-level and FastMCP servers.""" +from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import AsyncIterator import anyio import pytest diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 65cf061e..59cb30c8 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator +from collections.abc import AsyncGenerator import anyio import pytest diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 87129ba9..43107b59 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,7 +1,7 @@ import multiprocessing import socket import time -from typing import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator import anyio import httpx @@ -139,7 +139,7 @@ def server(server_port: int) -> Generator[None, None, None]: attempt += 1 else: raise RuntimeError( - "Server failed to start after {} attempts".format(max_attempts) + f"Server failed to start after {max_attempts} attempts" ) yield diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index bdc5160a..2aca97e1 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,7 +1,7 @@ import multiprocessing import socket import time -from typing import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator import anyio import pytest @@ -135,7 +135,7 @@ def server(server_port: int) -> Generator[None, None, None]: attempt += 1 else: raise RuntimeError( - "Server failed to start after {} attempts".format(max_attempts) + f"Server failed to start after {max_attempts} attempts" ) yield