|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import asyncio
|
| 4 | +import dataclasses |
4 | 5 | import inspect
|
5 | 6 | from collections.abc import Awaitable
|
6 | 7 | from dataclasses import dataclass
|
|
51 | 52 | from .models.interface import ModelTracing
|
52 | 53 | from .run_context import RunContextWrapper, TContext
|
53 | 54 | from .stream_events import RunItemStreamEvent, StreamEvent
|
54 |
| -from .tool import ComputerTool, FunctionTool, FunctionToolResult |
| 55 | +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool |
55 | 56 | from .tracing import (
|
56 | 57 | SpanError,
|
57 | 58 | Trace,
|
@@ -208,34 +209,22 @@ async def execute_tools_and_side_effects(
|
208 | 209 | new_step_items.extend(computer_results)
|
209 | 210 |
|
210 | 211 | # Reset tool_choice to "auto" after tool execution to prevent infinite loops
|
211 |
| - if (processed_response.functions or processed_response.computer_actions): |
212 |
| - # Reset agent's model_settings |
213 |
| - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): |
214 |
| - # Create a new model_settings to avoid modifying the original shared instance |
215 |
| - agent.model_settings = ModelSettings( |
216 |
| - temperature=agent.model_settings.temperature, |
217 |
| - top_p=agent.model_settings.top_p, |
218 |
| - frequency_penalty=agent.model_settings.frequency_penalty, |
219 |
| - presence_penalty=agent.model_settings.presence_penalty, |
220 |
| - tool_choice="auto", # Reset to auto |
221 |
| - parallel_tool_calls=agent.model_settings.parallel_tool_calls, |
222 |
| - truncation=agent.model_settings.truncation, |
223 |
| - max_tokens=agent.model_settings.max_tokens, |
| 212 | + if processed_response.functions or processed_response.computer_actions: |
| 213 | + tools = agent.tools |
| 214 | + # Only reset in the problematic scenarios where loops are likely unintentional |
| 215 | + if cls._should_reset_tool_choice(agent.model_settings, tools): |
| 216 | + agent.model_settings = dataclasses.replace( |
| 217 | + agent.model_settings, |
| 218 | + tool_choice="auto" |
224 | 219 | )
|
225 |
| - |
226 |
| - # Also reset run_config's model_settings if it exists |
227 |
| - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or |
228 |
| - isinstance(run_config.model_settings.tool_choice, str)): |
229 |
| - # Create a new model_settings for run_config |
230 |
| - run_config.model_settings = ModelSettings( |
231 |
| - temperature=run_config.model_settings.temperature, |
232 |
| - top_p=run_config.model_settings.top_p, |
233 |
| - frequency_penalty=run_config.model_settings.frequency_penalty, |
234 |
| - presence_penalty=run_config.model_settings.presence_penalty, |
235 |
| - tool_choice="auto", # Reset to auto |
236 |
| - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, |
237 |
| - truncation=run_config.model_settings.truncation, |
238 |
| - max_tokens=run_config.model_settings.max_tokens, |
| 220 | + |
| 221 | + if ( |
| 222 | + run_config.model_settings and |
| 223 | + cls._should_reset_tool_choice(run_config.model_settings, tools) |
| 224 | + ): |
| 225 | + run_config.model_settings = dataclasses.replace( |
| 226 | + run_config.model_settings, |
| 227 | + tool_choice="auto" |
239 | 228 | )
|
240 | 229 |
|
241 | 230 | # Second, check if there are any handoffs
|
@@ -328,6 +317,24 @@ async def execute_tools_and_side_effects(
|
328 | 317 | next_step=NextStepRunAgain(),
|
329 | 318 | )
|
330 | 319 |
|
| 320 | + @classmethod |
| 321 | + def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool: |
| 322 | + if model_settings is None or model_settings.tool_choice is None: |
| 323 | + return False |
| 324 | + |
| 325 | + # A specific tool named in tool_choice forces that same tool call on every turn, |
| 326 | + if ( |
| 327 | + isinstance(model_settings.tool_choice, str) and |
| 328 | + model_settings.tool_choice not in ["auto", "required", "none"] |
| 329 | + ): |
| 330 | + return True |
| 331 | + |
| 332 | + # With tool_choice="required" and exactly one tool, the model must repeat that tool, |
| 333 | + if model_settings.tool_choice == "required": |
| 334 | + return len(tools) == 1 |
| 335 | + |
| 336 | + return False |
| 337 | + |
331 | 338 | @classmethod
|
332 | 339 | def process_model_response(
|
333 | 340 | cls,
|
|
0 commit comments