|
1 | 1 | """Base classes for FastMCP prompts."""
|
2 | 2 |
|
3 | 3 | import json
|
4 |
| -from typing import Any, Callable, Dict, Literal, Optional, Sequence, Union |
| 4 | +from typing import Any, Callable, Dict, Literal, Optional, Sequence, Awaitable |
5 | 5 | import inspect
|
6 | 6 |
|
7 |
| -from pydantic import BaseModel, Field, TypeAdapter, field_validator, validate_call |
| 7 | +from pydantic import BaseModel, Field, TypeAdapter, validate_call |
8 | 8 | from mcp.types import TextContent, ImageContent, EmbeddedResource
|
9 | 9 | import pydantic_core
|
10 | 10 |
|
| 11 | +CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource |
| 12 | + |
11 | 13 |
|
12 | 14 | class Message(BaseModel):
|
13 | 15 | """Base class for all prompt messages."""
|
14 | 16 |
|
15 | 17 | role: Literal["user", "assistant"]
|
16 |
| - content: Union[TextContent, ImageContent, EmbeddedResource] |
| 18 | + content: CONTENT_TYPES |
17 | 19 |
|
18 |
| - def __init__(self, content, **kwargs): |
| 20 | + def __init__(self, content: str | CONTENT_TYPES, **kwargs): |
| 21 | + if isinstance(content, str): |
| 22 | + content = TextContent(type="text", text=content) |
19 | 23 | super().__init__(content=content, **kwargs)
|
20 | 24 |
|
21 |
| - @field_validator("content", mode="before") |
22 |
| - def validate_content(cls, v): |
23 |
| - if isinstance(v, str): |
24 |
| - return TextContent(type="text", text=v) |
25 |
| - return v |
26 |
| - |
27 | 25 |
|
28 | 26 | class UserMessage(Message):
|
29 | 27 | """A message from the user."""
|
30 | 28 |
|
31 | 29 | role: Literal["user"] = "user"
|
32 | 30 |
|
| 31 | + def __init__(self, content: str | CONTENT_TYPES, **kwargs): |
| 32 | + super().__init__(content=content, **kwargs) |
| 33 | + |
33 | 34 |
|
34 | 35 | class AssistantMessage(Message):
|
35 | 36 | """A message from the assistant."""
|
36 | 37 |
|
37 | 38 | role: Literal["assistant"] = "assistant"
|
38 | 39 |
|
| 40 | + def __init__(self, content: str | CONTENT_TYPES, **kwargs): |
| 41 | + super().__init__(content=content, **kwargs) |
| 42 | + |
| 43 | + |
| 44 | +message_validator = TypeAdapter(UserMessage | AssistantMessage) |
39 | 45 |
|
40 |
| -message_validator = TypeAdapter(Union[UserMessage, AssistantMessage]) |
| 46 | +SyncPromptResult = ( |
| 47 | + str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] |
| 48 | +) |
| 49 | +PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] |
41 | 50 |
|
42 | 51 |
|
43 | 52 | class PromptArgument(BaseModel):
|
@@ -67,11 +76,18 @@ class Prompt(BaseModel):
|
67 | 76 | @classmethod
|
68 | 77 | def from_function(
|
69 | 78 | cls,
|
70 |
| - fn: Callable[..., Sequence[Message]], |
| 79 | + fn: Callable[..., PromptResult], |
71 | 80 | name: Optional[str] = None,
|
72 | 81 | description: Optional[str] = None,
|
73 | 82 | ) -> "Prompt":
|
74 |
| - """Create a Prompt from a function.""" |
| 83 | + """Create a Prompt from a function. |
| 84 | +
|
| 85 | + The function can return: |
| 86 | + - A string (converted to a message) |
| 87 | + - A Message object |
| 88 | + - A dict (converted to a message) |
| 89 | + - A sequence of any of the above |
| 90 | + """ |
75 | 91 | func_name = name or fn.__name__
|
76 | 92 |
|
77 | 93 | if func_name == "<lambda>":
|
|
0 commit comments