Skip to content

Commit 75f089f

Browse files
stainless-botmegamanics
authored andcommitted
fix(azure): ensure custom options can be passed to copy (openai#858)
1 parent 1455d54 commit 75f089f

File tree

3 files changed

+129
-13
lines changed

3 files changed

+129
-13
lines changed

Diff for: src/openai/_client.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import os
66
import asyncio
7-
from typing import Union, Mapping
8-
from typing_extensions import override
7+
from typing import Any, Union, Mapping
8+
from typing_extensions import Self, override
99

1010
import httpx
1111

@@ -164,12 +164,10 @@ def copy(
164164
set_default_headers: Mapping[str, str] | None = None,
165165
default_query: Mapping[str, object] | None = None,
166166
set_default_query: Mapping[str, object] | None = None,
167-
) -> OpenAI:
167+
_extra_kwargs: Mapping[str, Any] = {},
168+
) -> Self:
168169
"""
169170
Create a new client instance re-using the same options given to the current client with optional overriding.
170-
171-
It should be noted that this does not share the underlying httpx client class which may lead
172-
to performance issues.
173171
"""
174172
if default_headers is not None and set_default_headers is not None:
175173
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -199,6 +197,7 @@ def copy(
199197
max_retries=max_retries if is_given(max_retries) else self.max_retries,
200198
default_headers=headers,
201199
default_query=params,
200+
**_extra_kwargs,
202201
)
203202

204203
# Alias for `copy` for nicer inline usage, e.g.
@@ -374,12 +373,10 @@ def copy(
374373
set_default_headers: Mapping[str, str] | None = None,
375374
default_query: Mapping[str, object] | None = None,
376375
set_default_query: Mapping[str, object] | None = None,
377-
) -> AsyncOpenAI:
376+
_extra_kwargs: Mapping[str, Any] = {},
377+
) -> Self:
378378
"""
379379
Create a new client instance re-using the same options given to the current client with optional overriding.
380-
381-
It should be noted that this does not share the underlying httpx client class which may lead
382-
to performance issues.
383380
"""
384381
if default_headers is not None and set_default_headers is not None:
385382
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -409,6 +406,7 @@ def copy(
409406
max_retries=max_retries if is_given(max_retries) else self.max_retries,
410407
default_headers=headers,
411408
default_query=params,
409+
**_extra_kwargs,
412410
)
413411

414412
# Alias for `copy` for nicer inline usage, e.g.

Diff for: src/openai/lib/azure.py

+91-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import inspect
55
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload
6-
from typing_extensions import override
6+
from typing_extensions import Self, override
77

88
import httpx
99

@@ -178,7 +178,7 @@ def __init__(
178178
if default_query is None:
179179
default_query = {"api-version": api_version}
180180
else:
181-
default_query = {"api-version": api_version, **default_query}
181+
default_query = {**default_query, "api-version": api_version}
182182

183183
if base_url is None:
184184
if azure_endpoint is None:
@@ -212,9 +212,53 @@ def __init__(
212212
http_client=http_client,
213213
_strict_response_validation=_strict_response_validation,
214214
)
215+
self._api_version = api_version
215216
self._azure_ad_token = azure_ad_token
216217
self._azure_ad_token_provider = azure_ad_token_provider
217218

219+
@override
220+
def copy(
221+
self,
222+
*,
223+
api_key: str | None = None,
224+
organization: str | None = None,
225+
api_version: str | None = None,
226+
azure_ad_token: str | None = None,
227+
azure_ad_token_provider: AzureADTokenProvider | None = None,
228+
base_url: str | httpx.URL | None = None,
229+
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
230+
http_client: httpx.Client | None = None,
231+
max_retries: int | NotGiven = NOT_GIVEN,
232+
default_headers: Mapping[str, str] | None = None,
233+
set_default_headers: Mapping[str, str] | None = None,
234+
default_query: Mapping[str, object] | None = None,
235+
set_default_query: Mapping[str, object] | None = None,
236+
_extra_kwargs: Mapping[str, Any] = {},
237+
) -> Self:
238+
"""
239+
Create a new client instance re-using the same options given to the current client with optional overriding.
240+
"""
241+
return super().copy(
242+
api_key=api_key,
243+
organization=organization,
244+
base_url=base_url,
245+
timeout=timeout,
246+
http_client=http_client,
247+
max_retries=max_retries,
248+
default_headers=default_headers,
249+
set_default_headers=set_default_headers,
250+
default_query=default_query,
251+
set_default_query=set_default_query,
252+
_extra_kwargs={
253+
"api_version": api_version or self._api_version,
254+
"azure_ad_token": azure_ad_token or self._azure_ad_token,
255+
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
256+
**_extra_kwargs,
257+
},
258+
)
259+
260+
with_options = copy
261+
218262
def _get_azure_ad_token(self) -> str | None:
219263
if self._azure_ad_token is not None:
220264
return self._azure_ad_token
@@ -367,7 +411,7 @@ def __init__(
367411
if default_query is None:
368412
default_query = {"api-version": api_version}
369413
else:
370-
default_query = {"api-version": api_version, **default_query}
414+
default_query = {**default_query, "api-version": api_version}
371415

372416
if base_url is None:
373417
if azure_endpoint is None:
@@ -401,9 +445,53 @@ def __init__(
401445
http_client=http_client,
402446
_strict_response_validation=_strict_response_validation,
403447
)
448+
self._api_version = api_version
404449
self._azure_ad_token = azure_ad_token
405450
self._azure_ad_token_provider = azure_ad_token_provider
406451

452+
@override
453+
def copy(
454+
self,
455+
*,
456+
api_key: str | None = None,
457+
organization: str | None = None,
458+
api_version: str | None = None,
459+
azure_ad_token: str | None = None,
460+
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
461+
base_url: str | httpx.URL | None = None,
462+
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
463+
http_client: httpx.AsyncClient | None = None,
464+
max_retries: int | NotGiven = NOT_GIVEN,
465+
default_headers: Mapping[str, str] | None = None,
466+
set_default_headers: Mapping[str, str] | None = None,
467+
default_query: Mapping[str, object] | None = None,
468+
set_default_query: Mapping[str, object] | None = None,
469+
_extra_kwargs: Mapping[str, Any] = {},
470+
) -> Self:
471+
"""
472+
Create a new client instance re-using the same options given to the current client with optional overriding.
473+
"""
474+
return super().copy(
475+
api_key=api_key,
476+
organization=organization,
477+
base_url=base_url,
478+
timeout=timeout,
479+
http_client=http_client,
480+
max_retries=max_retries,
481+
default_headers=default_headers,
482+
set_default_headers=set_default_headers,
483+
default_query=default_query,
484+
set_default_query=set_default_query,
485+
_extra_kwargs={
486+
"api_version": api_version or self._api_version,
487+
"azure_ad_token": azure_ad_token or self._azure_ad_token,
488+
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
489+
**_extra_kwargs,
490+
},
491+
)
492+
493+
with_options = copy
494+
407495
async def _get_azure_ad_token(self) -> str | None:
408496
if self._azure_ad_token is not None:
409497
return self._azure_ad_token

Diff for: tests/lib/test_azure.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Union
2+
from typing_extensions import Literal
23

34
import pytest
45

@@ -34,3 +35,32 @@ def test_implicit_deployment_path(client: Client) -> None:
3435
req.url
3536
== "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
3637
)
38+
39+
40+
@pytest.mark.parametrize(
41+
"client,method",
42+
[
43+
(sync_client, "copy"),
44+
(sync_client, "with_options"),
45+
(async_client, "copy"),
46+
(async_client, "with_options"),
47+
],
48+
)
49+
def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
50+
if method == "copy":
51+
copied = client.copy()
52+
else:
53+
copied = client.with_options()
54+
55+
assert copied._custom_query == {"api-version": "2023-07-01"}
56+
57+
58+
@pytest.mark.parametrize(
59+
"client",
60+
[sync_client, async_client],
61+
)
62+
def test_client_copying_override_options(client: Client) -> None:
63+
copied = client.copy(
64+
api_version="2022-05-01",
65+
)
66+
assert copied._custom_query == {"api-version": "2022-05-01"}

0 commit comments

Comments
 (0)