Skip to content

Commit 070d167

Browse files
authored
Apply custom headers passed to client constructor (#268)
`Client.__init__` takes a `**kwargs` argument that gets passed to the underlying httpx client constructor. However, the current implementation doesn't honor any custom headers that are passed. This PR fixes that. Signed-off-by: Mattt Zmuda <[email protected]>
1 parent fe2398b commit 070d167

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

replicate/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,8 @@ def _build_httpx_client(
329329
timeout: Optional[httpx.Timeout] = None,
330330
**kwargs,
331331
) -> Union[httpx.Client, httpx.AsyncClient]:
332-
headers = {
333-
"User-Agent": f"replicate-python/{__version__}",
334-
}
332+
headers = kwargs.pop("headers", {})
333+
headers["User-Agent"] = f"replicate-python/{__version__}"
335334

336335
if (
337336
api_token := api_token or os.environ.get("REPLICATE_API_TOKEN")

tests/test_client.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,27 @@ async def test_server_error_handling():
8585
client._request("GET", "/")
8686
assert "status: 500" in str(exc_info.value)
8787
assert "detail: Server error occurred" in str(exc_info.value)
88+
89+
90+
def test_custom_headers_are_applied():
91+
import replicate
92+
from replicate.exceptions import ReplicateError
93+
94+
custom_headers = {"Custom-Header": "CustomValue"}
95+
96+
def mock_send(request: httpx.Request, **kwargs) -> httpx.Response:
97+
assert "Custom-Header" in request.headers
98+
assert request.headers["Custom-Header"] == "CustomValue"
99+
100+
return httpx.Response(401, json={})
101+
102+
client = replicate.Client(
103+
api_token="dummy_token",
104+
headers=custom_headers,
105+
transport=httpx.MockTransport(mock_send),
106+
)
107+
108+
try:
109+
client.accounts.current()
110+
except ReplicateError:
111+
pass

0 commit comments

Comments
 (0)