Skip to content

Commit 111fc9e

Browse files
authored
Adding extra_headers parameters to ModelSettings (#550)
1 parent 83ce49e commit 111fc9e

File tree

5 files changed

+100
-4
lines changed

5 files changed

+100
-4
lines changed

src/agents/extensions/models/litellm_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ async def _fetch_response(
286286
stream=stream,
287287
stream_options=stream_options,
288288
reasoning_effort=reasoning_effort,
289-
extra_headers=HEADERS,
289+
extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
290290
api_key=self.api_key,
291291
base_url=self.base_url,
292292
**extra_kwargs,

src/agents/model_settings.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass, fields, replace
44
from typing import Literal
55

6-
from openai._types import Body, Query
6+
from openai._types import Body, Headers, Query
77
from openai.types.shared import Reasoning
88

99

@@ -67,6 +67,10 @@ class ModelSettings:
6767
"""Additional body fields to provide with the request.
6868
Defaults to None if not provided."""
6969

70+
extra_headers: Headers | None = None
71+
"""Additional headers to provide with the request.
72+
Defaults to None if not provided."""
73+
7074
def resolve(self, override: ModelSettings | None) -> ModelSettings:
7175
"""Produce a new ModelSettings by overlaying any non-None values from the
7276
override on top of this instance."""

src/agents/models/openai_chatcompletions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ async def _fetch_response(
255255
stream_options=self._non_null_or_not_given(stream_options),
256256
store=self._non_null_or_not_given(store),
257257
reasoning_effort=self._non_null_or_not_given(reasoning_effort),
258-
extra_headers=HEADERS,
258+
extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
259259
extra_query=model_settings.extra_query,
260260
extra_body=model_settings.extra_body,
261261
metadata=self._non_null_or_not_given(model_settings.metadata),

src/agents/models/openai_responses.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ async def _fetch_response(
253253
tool_choice=tool_choice,
254254
parallel_tool_calls=parallel_tool_calls,
255255
stream=stream,
256-
extra_headers=_HEADERS,
256+
extra_headers={**_HEADERS, **(model_settings.extra_headers or {})},
257257
extra_query=model_settings.extra_query,
258258
extra_body=model_settings.extra_body,
259259
text=response_format,

tests/test_extra_headers.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import pytest
2+
from openai.types.chat.chat_completion import ChatCompletion, Choice
3+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
4+
5+
from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel
6+
7+
8+
@pytest.mark.allow_call_model_methods
9+
@pytest.mark.asyncio
10+
async def test_extra_headers_passed_to_openai_responses_model():
11+
"""
12+
Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client.
13+
"""
14+
called_kwargs = {}
15+
16+
class DummyResponses:
17+
async def create(self, **kwargs):
18+
nonlocal called_kwargs
19+
called_kwargs = kwargs
20+
class DummyResponse:
21+
id = "dummy"
22+
output = []
23+
usage = type(
24+
"Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
25+
)()
26+
return DummyResponse()
27+
28+
class DummyClient:
29+
def __init__(self):
30+
self.responses = DummyResponses()
31+
32+
33+
34+
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
35+
extra_headers = {"X-Test-Header": "test-value"}
36+
await model.get_response(
37+
system_instructions=None,
38+
input="hi",
39+
model_settings=ModelSettings(extra_headers=extra_headers),
40+
tools=[],
41+
output_schema=None,
42+
handoffs=[],
43+
tracing=ModelTracing.DISABLED,
44+
previous_response_id=None,
45+
)
46+
assert "extra_headers" in called_kwargs
47+
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
48+
49+
50+
51+
@pytest.mark.allow_call_model_methods
52+
@pytest.mark.asyncio
53+
async def test_extra_headers_passed_to_openai_client():
54+
"""
55+
Ensure extra_headers in ModelSettings is passed to the OpenAI client.
56+
"""
57+
called_kwargs = {}
58+
59+
class DummyCompletions:
60+
async def create(self, **kwargs):
61+
nonlocal called_kwargs
62+
called_kwargs = kwargs
63+
msg = ChatCompletionMessage(role="assistant", content="Hello")
64+
choice = Choice(index=0, finish_reason="stop", message=msg)
65+
return ChatCompletion(
66+
id="resp-id",
67+
created=0,
68+
model="fake",
69+
object="chat.completion",
70+
choices=[choice],
71+
usage=None,
72+
)
73+
74+
class DummyClient:
75+
def __init__(self):
76+
self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
77+
self.base_url = "https://api.openai.com"
78+
79+
model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
80+
extra_headers = {"X-Test-Header": "test-value"}
81+
await model.get_response(
82+
system_instructions=None,
83+
input="hi",
84+
model_settings=ModelSettings(extra_headers=extra_headers),
85+
tools=[],
86+
output_schema=None,
87+
handoffs=[],
88+
tracing=ModelTracing.DISABLED,
89+
previous_response_id=None,
90+
)
91+
assert "extra_headers" in called_kwargs
92+
assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"

0 commit comments

Comments
 (0)