Skip to content

Commit 166c88d

Browse files
authored
Merge pull request #872 from PrefectHQ/tools
Tool → FunctionTool
2 parents a77dfa1 + 3443daf commit 166c88d

File tree

8 files changed

+41
-59
lines changed

8 files changed

+41
-59
lines changed

cookbook/slackbot/start.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,13 @@ async def handle_message(payload: SlackPayload) -> Completed:
142142
ai_response_text,
143143
"green",
144144
)
145+
messages = await assistant_thread.get_messages_async()
146+
145147
event = emit_assistant_completed_event(
146148
child_assistant=ai,
147149
parent_app=get_parent_app() if ENABLE_PARENT_APP else None,
148150
payload={
149-
"messages": await assistant_thread.get_messages_async(
150-
json_compatible=True
151-
),
151+
"messages": [m.model_dump() for m in messages],
152152
"metadata": assistant_thread.metadata,
153153
"user": {
154154
"id": event.user,

src/marvin/_mappings/base_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic import BaseModel
44
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode
55

6-
from marvin.types import Function, Tool, ToolSet
6+
from marvin.types import Function, FunctionTool, ToolSet
77

88

99
class FunctionSchema(GenerateJsonSchema):
@@ -15,10 +15,10 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"):
1515

1616
def cast_model_to_tool(
1717
model: type[BaseModel],
18-
) -> Tool[BaseModel]:
18+
) -> FunctionTool[BaseModel]:
1919
model_name = model.__name__
2020
model_description = model.__doc__
21-
return Tool[BaseModel](
21+
return FunctionTool[BaseModel](
2222
type="function",
2323
function=Function[BaseModel](
2424
name=model_name,

src/marvin/_mappings/types.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic.fields import FieldInfo
77

88
from marvin.settings import settings
9-
from marvin.types import Grammar, Tool, ToolSet
9+
from marvin.types import FunctionTool, Grammar, ToolSet
1010

1111
from .base_model import cast_model_to_tool, cast_model_to_toolset
1212

@@ -46,7 +46,7 @@ def cast_type_to_tool(
4646
field_name: str,
4747
field_description: str,
4848
python_function: Optional[Callable[..., Any]] = None,
49-
) -> Tool[BaseModel]:
49+
) -> FunctionTool[BaseModel]:
5050
return cast_model_to_tool(
5151
model=cast_type_to_model(
5252
_type,

src/marvin/beta/applications/state/state.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jsonpatch import JsonPatch
66
from pydantic import BaseModel, Field, PrivateAttr, SerializeAsAny
77

8-
from marvin.types import Tool
8+
from marvin.types import FunctionTool
99
from marvin.utilities.tools import tool_from_function
1010

1111

@@ -66,7 +66,7 @@ def update_state_jsonpatches(self, patches: list[JSONPatchModel]):
6666
self.set_state(state)
6767
return "Application state updated successfully!"
6868

69-
def as_tool(self, name: str = None) -> "Tool":
69+
def as_tool(self, name: str = None) -> "FunctionTool":
7070
if name is None:
7171
name = "state"
7272
schema = self.get_schema()

src/marvin/beta/assistants/runs.py

+10-20
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from openai.types.beta.threads.run import Run as OpenAIRun
88
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
9-
from pydantic import BaseModel, Field, PrivateAttr, field_validator
9+
from pydantic import BaseModel, Field, field_validator
1010

1111
import marvin.utilities.openai
1212
import marvin.utilities.tools
@@ -39,6 +39,7 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
3939
data (Any): Any additional data associated with the run.
4040
"""
4141

42+
id: Optional[str] = None
4243
thread: Thread
4344
assistant: Assistant
4445
instructions: Optional[str] = Field(
@@ -77,15 +78,15 @@ async def refresh_async(self):
7778
"""Refreshes the run."""
7879
client = marvin.utilities.openai.get_openai_client()
7980
self.run = await client.beta.threads.runs.retrieve(
80-
run_id=self.run.id, thread_id=self.thread.id
81+
run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
8182
)
8283

8384
@expose_sync_method("cancel")
8485
async def cancel_async(self):
8586
"""Cancels the run."""
8687
client = marvin.utilities.openai.get_openai_client()
8788
await client.beta.threads.runs.cancel(
88-
run_id=self.run.id, thread_id=self.thread.id
89+
run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
8990
)
9091

9192
async def _handle_step_requires_action(
@@ -156,6 +157,10 @@ async def run_async(self) -> "Run":
156157
if self.tools is not None or self.additional_tools is not None:
157158
create_kwargs["tools"] = self.get_tools()
158159

160+
if self.id is not None:
161+
raise ValueError(
162+
"This run object was provided an ID; can not create a new run."
163+
)
159164
async with self.assistant:
160165
self.run = await client.beta.threads.runs.create(
161166
thread_id=self.thread.id,
@@ -195,25 +200,10 @@ async def run_async(self) -> "Run":
195200

196201

197202
class RunMonitor(BaseModel):
198-
run_id: str
199-
thread_id: str
200-
_run: Run = PrivateAttr()
201-
_thread: Thread = PrivateAttr()
203+
run: Run
204+
thread: Thread
202205
steps: list[OpenAIRunStep] = []
203206

204-
def __init__(self, **kwargs):
205-
super().__init__(**kwargs)
206-
self._thread = Thread(**kwargs["thread_id"])
207-
self._run = Run(**kwargs["run_id"], thread=self.thread)
208-
209-
@property
210-
def thread(self):
211-
return self._thread
212-
213-
@property
214-
def run(self):
215-
return self._run
216-
217207
async def refresh_run_steps_async(self):
218208
"""
219209
Asynchronously refreshes and updates the run steps list.

src/marvin/beta/assistants/threads.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import time
3-
from typing import TYPE_CHECKING, Callable, Optional, Union
3+
from typing import TYPE_CHECKING, Callable, Optional
44

55
# for openai < 1.14.0
66
try:
@@ -18,7 +18,6 @@
1818
run_sync,
1919
)
2020
from marvin.utilities.logging import get_logger
21-
from marvin.utilities.pydantic import parse_as
2221

2322
logger = get_logger("Threads")
2423

@@ -100,25 +99,18 @@ async def get_messages_async(
10099
limit: int = None,
101100
before_message: Optional[str] = None,
102101
after_message: Optional[str] = None,
103-
json_compatible: bool = False,
104-
) -> list[Union[Message, dict]]:
102+
) -> list[Message]:
105103
"""
106104
Asynchronously retrieves messages from the thread.
107105
108106
Args:
109107
limit (int, optional): The maximum number of messages to return.
110-
before_message (str, optional): The ID of the message to start the list from,
111-
retrieving messages sent before this one.
112-
after_message (str, optional): The ID of the message to start the list from,
113-
retrieving messages sent after this one.
114-
json_compatible (bool, optional): If True, returns messages as dictionaries.
115-
If False, returns messages as Message
116-
objects. Default is False.
117-
108+
before_message (str, optional): The ID of the message to start the
109+
list from, retrieving messages sent before this one.
110+
after_message (str, optional): The ID of the message to start the
111+
list from, retrieving messages sent after this one.
118112
Returns:
119-
list[Union[Message, dict]]: A list of messages from the thread, either
120-
as dictionaries or Message objects,
121-
depending on the value of json_compatible.
113+
list[Union[Message, dict]]: A list of messages from the thread
122114
"""
123115

124116
if self.id is None:
@@ -134,10 +126,7 @@ async def get_messages_async(
134126
limit=limit,
135127
order="desc",
136128
)
137-
138-
T = dict if json_compatible else Message
139-
140-
return parse_as(list[T], reversed(response.model_dump()["data"]))
129+
return response.data
141130

142131
@expose_sync_method("delete")
143132
async def delete_async(self):

src/marvin/types.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,24 @@ def create(
6060
return instance
6161

6262

63-
class Tool(MarvinType, Generic[T]):
63+
class Tool(MarvinType):
6464
type: str
65+
66+
67+
class FunctionTool(Tool, Generic[T]):
6568
function: Optional[Function[T]] = None
6669

6770

6871
class ToolSet(MarvinType, Generic[T]):
69-
tools: Optional[list[Tool[T]]] = None
72+
tools: Optional[list[Union[FunctionTool[T], Tool]]] = None
7073
tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None
7174

7275

73-
class RetrievalTool(Tool[T]):
76+
class RetrievalTool(Tool):
7477
type: Literal["retrieval"] = "retrieval"
7578

7679

77-
class CodeInterpreterTool(Tool[T]):
80+
class CodeInterpreterTool(Tool):
7881
type: Literal["code_interpreter"] = "code_interpreter"
7982

8083

@@ -244,7 +247,7 @@ class Run(MarvinType, Generic[T]):
244247
status: str
245248
model: str
246249
instructions: Optional[str]
247-
tools: Optional[list[Tool[T]]] = None
250+
tools: Optional[list[FunctionTool[T]]] = None
248251
metadata: dict[str, str]
249252

250253

src/marvin/utilities/tools.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pydantic.fields import FieldInfo
1818
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode
1919

20-
from marvin.types import Function, Tool
20+
from marvin.types import Function, FunctionTool
2121
from marvin.utilities.asyncio import run_sync
2222
from marvin.utilities.logging import get_logger
2323

@@ -63,7 +63,7 @@ def generate(self, schema: Any, mode: JsonSchemaMode = "validation"):
6363
return json_schema
6464

6565

66-
def tool_from_type(type_: U, tool_name: str = None) -> Tool[U]:
66+
def tool_from_type(type_: U, tool_name: str = None) -> FunctionTool[U]:
6767
"""
6868
Creates an OpenAI-compatible tool from a Python type.
6969
"""
@@ -99,7 +99,7 @@ def tool_from_model(model: type[M], python_fn: Callable[[str], M] = None):
9999
def tool_fn(**data) -> M:
100100
return TypeAdapter(model).validate_python(data)
101101

102-
return Tool[M](
102+
return FunctionTool[M](
103103
type="function",
104104
function=Function[M].create(
105105
name=model.__name__,
@@ -130,7 +130,7 @@ def tool_from_function(
130130
fn, config=pydantic.ConfigDict(arbitrary_types_allowed=True)
131131
).json_schema()
132132

133-
return Tool[T](
133+
return FunctionTool[T](
134134
type="function",
135135
function=Function[T].create(
136136
name=name or fn.__name__,
@@ -142,7 +142,7 @@ def tool_from_function(
142142

143143

144144
def call_function_tool(
145-
tools: list[Tool],
145+
tools: list[FunctionTool],
146146
function_name: str,
147147
function_arguments_json: str,
148148
return_string: bool = False,

0 commit comments

Comments
 (0)