Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pyright + tests #52

Merged
merged 8 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ jobs:
python-version: "3.12"
- name: Run pre-commit
uses: pre-commit/[email protected]
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[tests]"
- name: Run pyright
run: pyright src tests
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ build-backend = "hatchling.build"
[project.optional-dependencies]
tests = [
"pre-commit",
"pyright>=1.1.389",
"pytest>=8.3.3",
"pytest-asyncio>=0.23.5",
"pytest-flakefinder",
Expand All @@ -39,3 +40,15 @@ asyncio_default_fixture_loop_scope = "session"

[tool.hatch.version]
source = "vcs"

[tool.pyright]
include = ["src", "tests"]
exclude = ["**/node_modules", "**/__pycache__", ".venv", ".git", "dist"]
pythonVersion = "3.10"
pythonPlatform = "Darwin"
typeCheckingMode = "basic"
reportMissingImports = true
reportMissingTypeStubs = false
useLibraryCodeForTypes = true
venvPath = "."
venv = ".venv"
7 changes: 6 additions & 1 deletion src/fastmcp/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib.metadata
import importlib.util
import os
import subprocess
import sys
from pathlib import Path
Expand Down Expand Up @@ -242,6 +243,7 @@ def dev(
[npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd,
check=True,
shell=shell,
env=dict(os.environ.items()), # Convert to list of tuples for env update
)
sys.exit(process.returncode)
except subprocess.CalledProcessError as e:
Expand Down Expand Up @@ -423,7 +425,10 @@ def install(
# Load from .env file if specified
if env_file:
try:
env_dict.update(dotenv.dotenv_values(env_file))
env_values = dotenv.dotenv_values(env_file)
env_dict.update(
(k, str(v)) for k, v in env_values.items() if v is not None
)
except Exception as e:
logger.error(f"Failed to load .env file: {e}")
sys.exit(1)
Expand Down
42 changes: 29 additions & 13 deletions src/fastmcp/prompts/base.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,52 @@
"""Base classes for FastMCP prompts."""

import json
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Union
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Awaitable
import inspect

from pydantic import BaseModel, Field, TypeAdapter, field_validator, validate_call
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.types import TextContent, ImageContent, EmbeddedResource
import pydantic_core

CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource


class Message(BaseModel):
"""Base class for all prompt messages."""

role: Literal["user", "assistant"]
content: Union[TextContent, ImageContent, EmbeddedResource]
content: CONTENT_TYPES

def __init__(self, content, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)

@field_validator("content", mode="before")
def validate_content(cls, v):
if isinstance(v, str):
return TextContent(type="text", text=v)
return v


class UserMessage(Message):
"""A message from the user."""

role: Literal["user"] = "user"

def __init__(self, content: str | CONTENT_TYPES, **kwargs):
super().__init__(content=content, **kwargs)


class AssistantMessage(Message):
"""A message from the assistant."""

role: Literal["assistant"] = "assistant"

def __init__(self, content: str | CONTENT_TYPES, **kwargs):
super().__init__(content=content, **kwargs)


message_validator = TypeAdapter(UserMessage | AssistantMessage)

message_validator = TypeAdapter(Union[UserMessage, AssistantMessage])
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
)
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]


class PromptArgument(BaseModel):
Expand Down Expand Up @@ -67,11 +76,18 @@ class Prompt(BaseModel):
@classmethod
def from_function(
cls,
fn: Callable[..., Sequence[Message]],
fn: Callable[..., PromptResult],
name: Optional[str] = None,
description: Optional[str] = None,
) -> "Prompt":
"""Create a Prompt from a function."""
"""Create a Prompt from a function.

The function can return:
- A string (converted to a message)
- A Message object
- A dict (converted to a message)
- A sequence of any of the above
"""
func_name = name or fn.__name__

if func_name == "<lambda>":
Expand Down
18 changes: 5 additions & 13 deletions src/fastmcp/resources/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Base classes and interfaces for FastMCP resources."""

import abc
from typing import Union
from typing import Union, Annotated

from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
FileUrl,
UrlConstraints,
ValidationInfo,
field_validator,
)
Expand All @@ -19,8 +19,9 @@ class Resource(BaseModel, abc.ABC):

model_config = ConfigDict(validate_default=True)

# uri: Annotated[AnyUrl, BeforeValidator(maybe_cast_str_to_any_url)] = Field(
uri: AnyUrl = Field(default=..., description="URI of the resource")
uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(
default=..., description="URI of the resource"
)
name: str | None = Field(description="Name of the resource", default=None)
description: str | None = Field(
description="Description of the resource", default=None
Expand All @@ -31,15 +32,6 @@ class Resource(BaseModel, abc.ABC):
pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$",
)

@field_validator("uri", mode="before")
def validate_uri(cls, uri: AnyUrl | str) -> AnyUrl:
if isinstance(uri, str):
# AnyUrl doesn't support triple-slashes, but files do ("file:///absolute/path")
if uri.startswith("file://"):
return FileUrl(uri)
return AnyUrl(uri)
return uri

@field_validator("name", mode="before")
@classmethod
def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/fastmcp/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def create_resource(self, uri: str, params: Dict[str, Any]) -> Resource:
result = await result

return FunctionResource(
uri=uri,
uri=uri, # type: ignore
name=self.name,
description=self.description,
mime_type=self.mime_type,
Expand Down
11 changes: 10 additions & 1 deletion src/fastmcp/resources/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import httpx
import pydantic.json
import pydantic_core
from pydantic import Field
from pydantic import Field, ValidationInfo

from fastmcp.resources.base import Resource

Expand Down Expand Up @@ -91,6 +91,15 @@ def validate_absolute_path(cls, path: Path) -> Path:
raise ValueError("Path must be absolute")
return path

@pydantic.field_validator("is_binary")
@classmethod
def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
"""Set is_binary based on mime_type if not explicitly set."""
if is_binary:
return True
mime_type = info.data.get("mime_type", "text/plain")
return not mime_type.startswith("text/")

async def read(self) -> Union[str, bytes]:
"""Read the file content."""
try:
Expand Down
13 changes: 7 additions & 6 deletions src/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from mcp.types import (
Prompt as MCPPrompt,
PromptArgument as MCPPromptArgument,
)
from mcp.types import (
Resource as MCPResource,
Expand Down Expand Up @@ -159,7 +160,7 @@ def get_context(self) -> "Context":

async def call_tool(
self, name: str, arguments: dict
) -> Sequence[TextContent | ImageContent]:
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
"""Call a tool by name with arguments."""
context = self.get_context()
result = await self._tool_manager.call_tool(name, arguments, context=context)
Expand Down Expand Up @@ -462,11 +463,11 @@ async def list_prompts(self) -> list[MCPPrompt]:
name=prompt.name,
description=prompt.description,
arguments=[
{
"name": arg.name,
"description": arg.description,
"required": arg.required,
}
MCPPromptArgument(
name=arg.name,
description=arg.description,
required=arg.required,
)
for arg in (prompt.arguments or [])
],
)
Expand Down
11 changes: 8 additions & 3 deletions src/fastmcp/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class FuncMetadata(BaseModel):

async def call_fn_with_arg_validation(
self,
fn: Callable | Awaitable,
fn: Callable[..., Any] | Awaitable[Any],
fn_is_async: bool,
arguments_to_validate: dict[str, Any],
arguments_to_pass_directly: dict[str, Any] | None,
Expand All @@ -64,8 +64,12 @@ async def call_fn_with_arg_validation(
arguments_parsed_dict |= arguments_to_pass_directly or {}

if fn_is_async:
if isinstance(fn, Awaitable):
return await fn
return await fn(**arguments_parsed_dict)
return fn(**arguments_parsed_dict)
if isinstance(fn, Callable):
return fn(**arguments_parsed_dict)
raise TypeError("fn must be either Callable or Awaitable")

def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
"""Pre-parse data from JSON.
Expand Down Expand Up @@ -123,6 +127,7 @@ def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadat
sig = _get_typed_signature(func)
params = sig.parameters
dynamic_pydantic_model_params: dict[str, Any] = {}
globalns = getattr(func, "__globals__", {})
for param in params.values():
if param.name.startswith("_"):
raise InvalidSignature(
Expand Down Expand Up @@ -153,7 +158,7 @@ def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadat
]

field_info = FieldInfo.from_annotated_attribute(
annotation,
_get_typed_annotation(annotation, globalns),
param.default
if param.default is not inspect.Parameter.empty
else PydanticUndefined,
Expand Down
4 changes: 3 additions & 1 deletion src/fastmcp/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def to_image_content(self) -> ImageContent:
if self.path:
with open(self.path, "rb") as f:
data = base64.b64encode(f.read()).decode()
else:
elif self.data is not None:
data = base64.b64encode(self.data).decode()
else:
raise ValueError("No image data available")

return ImageContent(type="image", data=data, mimeType=self._mime_type)
13 changes: 7 additions & 6 deletions tests/prompts/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pydantic import FileUrl
import pytest
from fastmcp.prompts.base import (
Prompt,
Expand Down Expand Up @@ -102,7 +103,7 @@ async def fn() -> UserMessage:
content=EmbeddedResource(
type="resource",
resource=TextResourceContents(
uri="file://file.txt",
uri=FileUrl("file://file.txt"),
text="File contents",
mimeType="text/plain",
),
Expand All @@ -115,7 +116,7 @@ async def fn() -> UserMessage:
content=EmbeddedResource(
type="resource",
resource=TextResourceContents(
uri="file://file.txt",
uri=FileUrl("file://file.txt"),
text="File contents",
mimeType="text/plain",
),
Expand All @@ -133,7 +134,7 @@ async def fn() -> list[Message]:
content=EmbeddedResource(
type="resource",
resource=TextResourceContents(
uri="file://file.txt",
uri=FileUrl("file://file.txt"),
text="File contents",
mimeType="text/plain",
),
Expand All @@ -151,7 +152,7 @@ async def fn() -> list[Message]:
content=EmbeddedResource(
type="resource",
resource=TextResourceContents(
uri="file://file.txt",
uri=FileUrl("file://file.txt"),
text="File contents",
mimeType="text/plain",
),
Expand All @@ -171,7 +172,7 @@ async def fn() -> dict:
"content": {
"type": "resource",
"resource": {
"uri": "file://file.txt",
"uri": FileUrl("file://file.txt"),
"text": "File contents",
"mimeType": "text/plain",
},
Expand All @@ -184,7 +185,7 @@ async def fn() -> dict:
content=EmbeddedResource(
type="resource",
resource=TextResourceContents(
uri="file://file.txt",
uri=FileUrl("file://file.txt"),
text="File contents",
mimeType="text/plain",
),
Expand Down
Loading
Loading