Skip to content

Commit bbcda75

Browse files
committed
fix: optimize tool_choice reset logic and fix lint errors
- Refactor tool_choice reset to target only problematic edge cases - Replace manual ModelSettings recreation with dataclasses.replace - Fix line length and error handling lint issues in tests
1 parent d169d79 commit bbcda75

File tree

2 files changed

+137
-147
lines changed

2 files changed

+137
-147
lines changed

src/agents/_run_impl.py

+35-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import dataclasses
45
import inspect
56
from collections.abc import Awaitable
67
from dataclasses import dataclass
@@ -51,7 +52,7 @@
5152
from .models.interface import ModelTracing
5253
from .run_context import RunContextWrapper, TContext
5354
from .stream_events import RunItemStreamEvent, StreamEvent
54-
from .tool import ComputerTool, FunctionTool, FunctionToolResult
55+
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
5556
from .tracing import (
5657
SpanError,
5758
Trace,
@@ -208,34 +209,22 @@ async def execute_tools_and_side_effects(
208209
new_step_items.extend(computer_results)
209210

210211
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
211-
if (processed_response.functions or processed_response.computer_actions):
212-
# Reset agent's model_settings
213-
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
214-
# Create a new model_settings to avoid modifying the original shared instance
215-
agent.model_settings = ModelSettings(
216-
temperature=agent.model_settings.temperature,
217-
top_p=agent.model_settings.top_p,
218-
frequency_penalty=agent.model_settings.frequency_penalty,
219-
presence_penalty=agent.model_settings.presence_penalty,
220-
tool_choice="auto", # Reset to auto
221-
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
222-
truncation=agent.model_settings.truncation,
223-
max_tokens=agent.model_settings.max_tokens,
212+
if processed_response.functions or processed_response.computer_actions:
213+
tools = agent.tools
214+
# Only reset in the problematic scenarios where loops are likely unintentional
215+
if cls._should_reset_tool_choice(agent.model_settings, tools):
216+
agent.model_settings = dataclasses.replace(
217+
agent.model_settings,
218+
tool_choice="auto"
224219
)
225-
226-
# Also reset run_config's model_settings if it exists
227-
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
228-
isinstance(run_config.model_settings.tool_choice, str)):
229-
# Create a new model_settings for run_config
230-
run_config.model_settings = ModelSettings(
231-
temperature=run_config.model_settings.temperature,
232-
top_p=run_config.model_settings.top_p,
233-
frequency_penalty=run_config.model_settings.frequency_penalty,
234-
presence_penalty=run_config.model_settings.presence_penalty,
235-
tool_choice="auto", # Reset to auto
236-
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
237-
truncation=run_config.model_settings.truncation,
238-
max_tokens=run_config.model_settings.max_tokens,
220+
221+
if (
222+
run_config.model_settings and
223+
cls._should_reset_tool_choice(run_config.model_settings, tools)
224+
):
225+
run_config.model_settings = dataclasses.replace(
226+
run_config.model_settings,
227+
tool_choice="auto"
239228
)
240229

241230
# Second, check if there are any handoffs
@@ -328,6 +317,24 @@ async def execute_tools_and_side_effects(
328317
next_step=NextStepRunAgain(),
329318
)
330319

320+
@classmethod
321+
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
322+
if model_settings is None or model_settings.tool_choice is None:
323+
return False
324+
325+
# for specific tool choices
326+
if (
327+
isinstance(model_settings.tool_choice, str) and
328+
model_settings.tool_choice not in ["auto", "required", "none"]
329+
):
330+
return True
331+
332+
# for one tool and required tool choice
333+
if model_settings.tool_choice == "required":
334+
return len(tools) == 1
335+
336+
return False
337+
331338
@classmethod
332339
def process_model_response(
333340
cls,

0 commit comments

Comments
 (0)