2
2
3
3
import inspect
4
4
import json
5
- from collections .abc import Callable
6
- from typing import Any , Awaitable , Literal , Sequence
5
+ from collections .abc import Awaitable , Callable , Sequence
6
+ from typing import Any , Literal
7
7
8
8
import pydantic_core
9
9
from pydantic import BaseModel , Field , TypeAdapter , validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
19
19
role : Literal ["user" , "assistant" ]
20
20
content : CONTENT_TYPES
21
21
22
- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
22
+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
23
23
if isinstance (content , str ):
24
24
content = TextContent (type = "text" , text = content )
25
25
super ().__init__ (content = content , ** kwargs )
@@ -30,7 +30,7 @@ class UserMessage(Message):
30
30
31
31
role : Literal ["user" , "assistant" ] = "user"
32
32
33
- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
33
+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
34
34
super ().__init__ (content = content , ** kwargs )
35
35
36
36
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
39
39
40
40
role : Literal ["user" , "assistant" ] = "assistant"
41
41
42
- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
42
+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
43
43
super ().__init__ (content = content , ** kwargs )
44
44
45
45
46
- message_validator = TypeAdapter (UserMessage | AssistantMessage )
46
+ message_validator = TypeAdapter [UserMessage | AssistantMessage ](
47
+ UserMessage | AssistantMessage
48
+ )
47
49
48
50
SyncPromptResult = (
49
51
str | Message | dict [str , Any ] | Sequence [str | Message | dict [str , Any ]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
73
75
arguments : list [PromptArgument ] | None = Field (
74
76
None , description = "Arguments that can be passed to the prompt"
75
77
)
76
- fn : Callable = Field (exclude = True )
78
+ fn : Callable [..., PromptResult | Awaitable [ PromptResult ]] = Field (exclude = True )
77
79
78
80
@classmethod
79
81
def from_function (
80
82
cls ,
81
- fn : Callable [..., PromptResult ],
83
+ fn : Callable [..., PromptResult | Awaitable [ PromptResult ] ],
82
84
name : str | None = None ,
83
85
description : str | None = None ,
84
86
) -> "Prompt" :
@@ -99,7 +101,7 @@ def from_function(
99
101
parameters = TypeAdapter (fn ).json_schema ()
100
102
101
103
# Convert parameters to PromptArguments
102
- arguments = []
104
+ arguments : list [ PromptArgument ] = []
103
105
if "properties" in parameters :
104
106
for param_name , param in parameters ["properties" ].items ():
105
107
required = param_name in parameters .get ("required" , [])
@@ -138,25 +140,23 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
138
140
result = await result
139
141
140
142
# Validate messages
141
- if not isinstance (result , ( list , tuple ) ):
143
+ if not isinstance (result , list | tuple ):
142
144
result = [result ]
143
145
144
146
# Convert result to messages
145
- messages = []
146
- for msg in result :
147
+ messages : list [ Message ] = []
148
+ for msg in result : # type: ignore[reportUnknownVariableType]
147
149
try :
148
150
if isinstance (msg , Message ):
149
151
messages .append (msg )
150
152
elif isinstance (msg , dict ):
151
- msg = message_validator .validate_python (msg )
152
- messages .append (msg )
153
+ messages .append (message_validator .validate_python (msg ))
153
154
elif isinstance (msg , str ):
154
- messages .append (
155
- UserMessage (content = TextContent (type = "text" , text = msg ))
156
- )
155
+ content = TextContent (type = "text" , text = msg )
156
+ messages .append (UserMessage (content = content ))
157
157
else :
158
- msg = json .dumps (pydantic_core .to_jsonable_python (msg ))
159
- messages .append (Message (role = "user" , content = msg ))
158
+ content = json .dumps (pydantic_core .to_jsonable_python (msg ))
159
+ messages .append (Message (role = "user" , content = content ))
160
160
except Exception :
161
161
raise ValueError (
162
162
f"Could not convert prompt result to message: { msg } "
0 commit comments