diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 0fc277c3..c72664a4 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -286,7 +286,7 @@ async def _fetch_response( stream=stream, stream_options=stream_options, reasoning_effort=reasoning_effort, - extra_headers=HEADERS, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, api_key=self.api_key, base_url=self.base_url, **extra_kwargs, diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index ed9a0131..fee92b4e 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, fields, replace from typing import Literal -from openai._types import Body, Query +from openai._types import Body, Headers, Query from openai.types.shared import Reasoning @@ -67,6 +67,10 @@ class ModelSettings: """Additional body fields to provide with the request. Defaults to None if not provided.""" + extra_headers: Headers | None = None + """Additional headers to provide with the request. + Defaults to None if not provided.""" + def resolve(self, override: ModelSettings | None) -> ModelSettings: """Produce a new ModelSettings by overlaying any non-None values from the override on top of this instance.""" diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 9989c1ee..e0882927 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -255,7 +255,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers=HEADERS, + extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index ab4617d4..5f067296 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -253,7 +253,7 @@ async def _fetch_response( tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls, stream=stream, - extra_headers=_HEADERS, + extra_headers={**_HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, text=response_format, diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py new file mode 100644 index 00000000..f29c2540 --- /dev/null +++ b/tests/test_extra_headers.py @@ -0,0 +1,92 @@ +import pytest +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_responses_model(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client. + """ + called_kwargs = {} + + class DummyResponses: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + class DummyResponse: + id = "dummy" + output = [] + usage = type( + "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + )() + return DummyResponse() + + class DummyClient: + def __init__(self): + self.responses = DummyResponses() + + + + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" + + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_client(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAI client. + """ + called_kwargs = {} + + class DummyCompletions: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + msg = ChatCompletionMessage(role="assistant", content="Hello") + choice = Choice(index=0, finish_reason="stop", message=msg) + return ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + usage=None, + ) + + class DummyClient: + def __init__(self): + self.chat = type("_Chat", (), {"completions": DummyCompletions()})() + self.base_url = "https://api.openai.com" + + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"