1
1
import asyncio
2
2
from typing import Any , Callable , Optional , Union
3
3
4
+ from openai .types .beta .threads .required_action_function_tool_call import (
5
+ RequiredActionFunctionToolCall ,
6
+ )
4
7
from openai .types .beta .threads .run import Run as OpenAIRun
5
8
from openai .types .beta .threads .runs import RunStep as OpenAIRunStep
6
9
from pydantic import BaseModel , Field , PrivateAttr , field_validator
@@ -85,11 +88,14 @@ async def cancel_async(self):
85
88
run_id = self .run .id , thread_id = self .thread .id
86
89
)
87
90
88
- async def _handle_step_requires_action (self ):
91
+ async def _handle_step_requires_action (
92
+ self ,
93
+ ) -> tuple [list [RequiredActionFunctionToolCall ], list [dict [str , str ]]]:
89
94
client = get_openai_client ()
90
95
if self .run .status != "requires_action" :
91
- return
96
+ return None , None
92
97
if self .run .required_action .type == "submit_tool_outputs" :
98
+ tool_calls = []
93
99
tool_outputs = []
94
100
tools = self .get_tools ()
95
101
@@ -110,10 +116,12 @@ async def _handle_step_requires_action(self):
110
116
tool_outputs .append (
111
117
dict (tool_call_id = tool_call .id , output = output or "" )
112
118
)
119
+ tool_calls .append (tool_call )
113
120
114
121
await client .beta .threads .runs .submit_tool_outputs (
115
122
thread_id = self .thread .id , run_id = self .run .id , tool_outputs = tool_outputs
116
123
)
124
+ return tool_calls , tool_outputs
117
125
118
126
def get_instructions (self ) -> str :
119
127
if self .instructions is None :
@@ -157,10 +165,16 @@ async def run_async(self) -> "Run":
157
165
158
166
self .assistant .pre_run_hook (run = self )
159
167
168
+ tool_calls = None
169
+ tool_outputs = None
170
+
160
171
try :
161
172
while self .run .status in ("queued" , "in_progress" , "requires_action" ):
162
173
if self .run .status == "requires_action" :
163
- await self ._handle_step_requires_action ()
174
+ (
175
+ tool_calls ,
176
+ tool_outputs ,
177
+ ) = await self ._handle_step_requires_action ()
164
178
await asyncio .sleep (0.1 )
165
179
await self .refresh_async ()
166
180
except CancelRun as exc :
@@ -174,7 +188,9 @@ async def run_async(self) -> "Run":
174
188
if self .run .status == "failed" :
175
189
logger .debug (f"Run failed. Last error was: { self .run .last_error } " )
176
190
177
- self .assistant .post_run_hook (run = self )
191
+ self .assistant .post_run_hook (
192
+ run = self , tool_calls = tool_calls , tool_outputs = tool_outputs
193
+ )
178
194
return self
179
195
180
196
0 commit comments