|
6 | 6 | )
|
7 | 7 | from openai.types.beta.threads.run import Run as OpenAIRun
|
8 | 8 | from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
|
9 |
| -from pydantic import BaseModel, Field, PrivateAttr, field_validator |
| 9 | +from pydantic import BaseModel, Field, field_validator |
10 | 10 |
|
11 | 11 | import marvin.utilities.openai
|
12 | 12 | import marvin.utilities.tools
|
@@ -39,6 +39,7 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
|
39 | 39 | data (Any): Any additional data associated with the run.
|
40 | 40 | """
|
41 | 41 |
|
| 42 | + id: Optional[str] = None |
42 | 43 | thread: Thread
|
43 | 44 | assistant: Assistant
|
44 | 45 | instructions: Optional[str] = Field(
|
@@ -77,15 +78,15 @@ async def refresh_async(self):
|
77 | 78 | """Refreshes the run."""
|
78 | 79 | client = marvin.utilities.openai.get_openai_client()
|
79 | 80 | self.run = await client.beta.threads.runs.retrieve(
|
80 |
| - run_id=self.run.id, thread_id=self.thread.id |
| 81 | + run_id=self.run.id if self.run else self.id, thread_id=self.thread.id |
81 | 82 | )
|
82 | 83 |
|
83 | 84 | @expose_sync_method("cancel")
|
84 | 85 | async def cancel_async(self):
|
85 | 86 | """Cancels the run."""
|
86 | 87 | client = marvin.utilities.openai.get_openai_client()
|
87 | 88 | await client.beta.threads.runs.cancel(
|
88 |
| - run_id=self.run.id, thread_id=self.thread.id |
| 89 | + run_id=self.run.id if self.run else self.id, thread_id=self.thread.id |
89 | 90 | )
|
90 | 91 |
|
91 | 92 | async def _handle_step_requires_action(
|
@@ -156,6 +157,10 @@ async def run_async(self) -> "Run":
|
156 | 157 | if self.tools is not None or self.additional_tools is not None:
|
157 | 158 | create_kwargs["tools"] = self.get_tools()
|
158 | 159 |
|
| 160 | + if self.id is not None: |
| 161 | + raise ValueError( |
| 162 | + "This run object was provided an ID; can not create a new run." |
| 163 | + ) |
159 | 164 | async with self.assistant:
|
160 | 165 | self.run = await client.beta.threads.runs.create(
|
161 | 166 | thread_id=self.thread.id,
|
@@ -195,25 +200,10 @@ async def run_async(self) -> "Run":
|
195 | 200 |
|
196 | 201 |
|
197 | 202 | class RunMonitor(BaseModel):
|
198 |
| - run_id: str |
199 |
| - thread_id: str |
200 |
| - _run: Run = PrivateAttr() |
201 |
| - _thread: Thread = PrivateAttr() |
| 203 | + run: Run |
| 204 | + thread: Thread |
202 | 205 | steps: list[OpenAIRunStep] = []
|
203 | 206 |
|
204 |
| - def __init__(self, **kwargs): |
205 |
| - super().__init__(**kwargs) |
206 |
| - self._thread = Thread(**kwargs["thread_id"]) |
207 |
| - self._run = Run(**kwargs["run_id"], thread=self.thread) |
208 |
| - |
209 |
| - @property |
210 |
| - def thread(self): |
211 |
| - return self._thread |
212 |
| - |
213 |
| - @property |
214 |
| - def run(self): |
215 |
| - return self._run |
216 |
| - |
217 | 207 | async def refresh_run_steps_async(self):
|
218 | 208 | """
|
219 | 209 | Asynchronously refreshes and updates the run steps list.
|
|
0 commit comments