Skip to content

Commit f786d4f

Browse files
authored
Merge pull request #835 from salman1993/main
Access to tool calls and tool outputs in post_run_hook
2 parents 8ee6b7c + ecda861 commit f786d4f

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/marvin/beta/assistants/assistants.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from typing import TYPE_CHECKING, Callable, Optional, Union
22

3+
from openai.types.beta.threads.required_action_function_tool_call import (
4+
RequiredActionFunctionToolCall,
5+
)
36
from pydantic import BaseModel, Field, PrivateAttr
47

58
import marvin.utilities.tools
@@ -168,5 +171,10 @@ def chat(self, thread: Thread = None):
168171
def pre_run_hook(self, run: "Run"):
169172
pass
170173

171-
def post_run_hook(self, run: "Run"):
174+
def post_run_hook(
175+
self,
176+
run: "Run",
177+
tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None,
178+
tool_outputs: Optional[list[dict[str, str]]] = None,
179+
):
172180
pass

src/marvin/beta/assistants/runs.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
from typing import Any, Callable, Optional, Union
33

4+
from openai.types.beta.threads.required_action_function_tool_call import (
5+
RequiredActionFunctionToolCall,
6+
)
47
from openai.types.beta.threads.run import Run as OpenAIRun
58
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
69
from pydantic import BaseModel, Field, PrivateAttr, field_validator
@@ -85,11 +88,14 @@ async def cancel_async(self):
8588
run_id=self.run.id, thread_id=self.thread.id
8689
)
8790

88-
async def _handle_step_requires_action(self):
91+
async def _handle_step_requires_action(
92+
self,
93+
) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
8994
client = get_openai_client()
9095
if self.run.status != "requires_action":
91-
return
96+
return None, None
9297
if self.run.required_action.type == "submit_tool_outputs":
98+
tool_calls = []
9399
tool_outputs = []
94100
tools = self.get_tools()
95101

@@ -110,10 +116,12 @@ async def _handle_step_requires_action(self):
110116
tool_outputs.append(
111117
dict(tool_call_id=tool_call.id, output=output or "")
112118
)
119+
tool_calls.append(tool_call)
113120

114121
await client.beta.threads.runs.submit_tool_outputs(
115122
thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs
116123
)
124+
return tool_calls, tool_outputs
117125

118126
def get_instructions(self) -> str:
119127
if self.instructions is None:
@@ -157,10 +165,16 @@ async def run_async(self) -> "Run":
157165

158166
self.assistant.pre_run_hook(run=self)
159167

168+
tool_calls = None
169+
tool_outputs = None
170+
160171
try:
161172
while self.run.status in ("queued", "in_progress", "requires_action"):
162173
if self.run.status == "requires_action":
163-
await self._handle_step_requires_action()
174+
(
175+
tool_calls,
176+
tool_outputs,
177+
) = await self._handle_step_requires_action()
164178
await asyncio.sleep(0.1)
165179
await self.refresh_async()
166180
except CancelRun as exc:
@@ -174,7 +188,9 @@ async def run_async(self) -> "Run":
174188
if self.run.status == "failed":
175189
logger.debug(f"Run failed. Last error was: {self.run.last_error}")
176190

177-
self.assistant.post_run_hook(run=self)
191+
self.assistant.post_run_hook(
192+
run=self, tool_calls=tool_calls, tool_outputs=tool_outputs
193+
)
178194
return self
179195

180196

0 commit comments

Comments
 (0)