Skip to content

Commit a6371ba

Browse files
fix: update module level client (#2185)
1 parent 32efcf3 commit a6371ba

File tree

3 files changed

+384
-0
lines changed

3 files changed

+384
-0
lines changed

Diff for: src/openai/__init__.py

+182
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
22

3+
from __future__ import annotations
4+
5+
from typing_extensions import override
6+
37
from . import types
48
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
59
from ._utils import file_from_path
@@ -82,3 +86,181 @@
8286
except (TypeError, AttributeError):
8387
# Some of our exported symbols are builtins which we can't set attributes for.
8488
pass
89+
90+
# ------ Module level client ------
91+
import typing as _t
92+
93+
import httpx as _httpx
94+
95+
from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
96+
97+
api_key: str | None = None
98+
99+
organization: str | None = None
100+
101+
project: str | None = None
102+
103+
base_url: str | _httpx.URL | None = None
104+
105+
timeout: float | Timeout | None = DEFAULT_TIMEOUT
106+
107+
max_retries: int = DEFAULT_MAX_RETRIES
108+
109+
default_headers: _t.Mapping[str, str] | None = None
110+
111+
default_query: _t.Mapping[str, object] | None = None
112+
113+
http_client: _httpx.Client | None = None
114+
115+
116+
class _ModuleClient(OpenAI):
117+
# Note: we have to use type: ignores here as overriding class members
118+
# with properties is technically unsafe but it is fine for our use case
119+
120+
@property # type: ignore
121+
@override
122+
def api_key(self) -> str | None:
123+
return api_key
124+
125+
@api_key.setter # type: ignore
126+
def api_key(self, value: str | None) -> None: # type: ignore
127+
global api_key
128+
129+
api_key = value
130+
131+
@property # type: ignore
132+
@override
133+
def organization(self) -> str | None:
134+
return organization
135+
136+
@organization.setter # type: ignore
137+
def organization(self, value: str | None) -> None: # type: ignore
138+
global organization
139+
140+
organization = value
141+
142+
@property # type: ignore
143+
@override
144+
def project(self) -> str | None:
145+
return project
146+
147+
@project.setter # type: ignore
148+
def project(self, value: str | None) -> None: # type: ignore
149+
global project
150+
151+
project = value
152+
153+
@property
154+
@override
155+
def base_url(self) -> _httpx.URL:
156+
if base_url is not None:
157+
return _httpx.URL(base_url)
158+
159+
return super().base_url
160+
161+
@base_url.setter
162+
def base_url(self, url: _httpx.URL | str) -> None:
163+
super().base_url = url # type: ignore[misc]
164+
165+
@property # type: ignore
166+
@override
167+
def timeout(self) -> float | Timeout | None:
168+
return timeout
169+
170+
@timeout.setter # type: ignore
171+
def timeout(self, value: float | Timeout | None) -> None: # type: ignore
172+
global timeout
173+
174+
timeout = value
175+
176+
@property # type: ignore
177+
@override
178+
def max_retries(self) -> int:
179+
return max_retries
180+
181+
@max_retries.setter # type: ignore
182+
def max_retries(self, value: int) -> None: # type: ignore
183+
global max_retries
184+
185+
max_retries = value
186+
187+
@property # type: ignore
188+
@override
189+
def _custom_headers(self) -> _t.Mapping[str, str] | None:
190+
return default_headers
191+
192+
@_custom_headers.setter # type: ignore
193+
def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore
194+
global default_headers
195+
196+
default_headers = value
197+
198+
@property # type: ignore
199+
@override
200+
def _custom_query(self) -> _t.Mapping[str, object] | None:
201+
return default_query
202+
203+
@_custom_query.setter # type: ignore
204+
def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore
205+
global default_query
206+
207+
default_query = value
208+
209+
@property # type: ignore
210+
@override
211+
def _client(self) -> _httpx.Client:
212+
return http_client or super()._client
213+
214+
@_client.setter # type: ignore
215+
def _client(self, value: _httpx.Client) -> None: # type: ignore
216+
global http_client
217+
218+
http_client = value
219+
220+
221+
_client: OpenAI | None = None
222+
223+
224+
def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
225+
global _client
226+
227+
if _client is None:
228+
_client = _ModuleClient(
229+
api_key=api_key,
230+
organization=organization,
231+
project=project,
232+
base_url=base_url,
233+
timeout=timeout,
234+
max_retries=max_retries,
235+
default_headers=default_headers,
236+
default_query=default_query,
237+
http_client=http_client,
238+
)
239+
return _client
240+
241+
return _client
242+
243+
244+
def _reset_client() -> None: # type: ignore[reportUnusedFunction]
245+
global _client
246+
247+
_client = None
248+
249+
250+
from ._module_client import (
251+
beta as beta,
252+
chat as chat,
253+
audio as audio,
254+
files as files,
255+
client as client,
256+
images as images,
257+
models as models,
258+
batches as batches,
259+
uploads as uploads,
260+
responses as responses,
261+
embeddings as embeddings,
262+
completions as completions,
263+
fine_tuning as fine_tuning,
264+
moderations as moderations,
265+
vector_stores as vector_stores,
266+
)

Diff for: src/openai/_module_client.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from typing_extensions import override
4+
5+
from . import resources, _load_client
6+
from ._utils import LazyProxy
7+
8+
9+
class ChatProxy(LazyProxy[resources.Chat]):
10+
@override
11+
def __load__(self) -> resources.Chat:
12+
return _load_client().chat
13+
14+
15+
class BetaProxy(LazyProxy[resources.Beta]):
16+
@override
17+
def __load__(self) -> resources.Beta:
18+
return _load_client().beta
19+
20+
21+
class FilesProxy(LazyProxy[resources.Files]):
22+
@override
23+
def __load__(self) -> resources.Files:
24+
return _load_client().files
25+
26+
27+
class AudioProxy(LazyProxy[resources.Audio]):
28+
@override
29+
def __load__(self) -> resources.Audio:
30+
return _load_client().audio
31+
32+
33+
class ImagesProxy(LazyProxy[resources.Images]):
34+
@override
35+
def __load__(self) -> resources.Images:
36+
return _load_client().images
37+
38+
39+
class ModelsProxy(LazyProxy[resources.Models]):
40+
@override
41+
def __load__(self) -> resources.Models:
42+
return _load_client().models
43+
44+
45+
class ClientProxy(LazyProxy[resources.OpenAI]):
46+
@override
47+
def __load__(self) -> resources.OpenAI:
48+
return _load_client().client
49+
50+
51+
class BatchesProxy(LazyProxy[resources.Batches]):
52+
@override
53+
def __load__(self) -> resources.Batches:
54+
return _load_client().batches
55+
56+
57+
class UploadsProxy(LazyProxy[resources.Uploads]):
58+
@override
59+
def __load__(self) -> resources.Uploads:
60+
return _load_client().uploads
61+
62+
63+
class ResponsesProxy(LazyProxy[resources.Responses]):
64+
@override
65+
def __load__(self) -> resources.Responses:
66+
return _load_client().responses
67+
68+
69+
class EmbeddingsProxy(LazyProxy[resources.Embeddings]):
70+
@override
71+
def __load__(self) -> resources.Embeddings:
72+
return _load_client().embeddings
73+
74+
75+
class CompletionsProxy(LazyProxy[resources.Completions]):
76+
@override
77+
def __load__(self) -> resources.Completions:
78+
return _load_client().completions
79+
80+
81+
class ModerationsProxy(LazyProxy[resources.Moderations]):
82+
@override
83+
def __load__(self) -> resources.Moderations:
84+
return _load_client().moderations
85+
86+
87+
class FineTuningProxy(LazyProxy[resources.FineTuning]):
88+
@override
89+
def __load__(self) -> resources.FineTuning:
90+
return _load_client().fine_tuning
91+
92+
93+
class VectorStoresProxy(LazyProxy[resources.VectorStores]):
94+
@override
95+
def __load__(self) -> resources.VectorStores:
96+
return _load_client().vector_stores
97+
98+
99+
chat: resources.Chat = ChatProxy().__as_proxied__()
100+
beta: resources.Beta = BetaProxy().__as_proxied__()
101+
files: resources.Files = FilesProxy().__as_proxied__()
102+
audio: resources.Audio = AudioProxy().__as_proxied__()
103+
images: resources.Images = ImagesProxy().__as_proxied__()
104+
models: resources.Models = ModelsProxy().__as_proxied__()
105+
client: resources.OpenAI = ClientProxy().__as_proxied__()
106+
batches: resources.Batches = BatchesProxy().__as_proxied__()
107+
uploads: resources.Uploads = UploadsProxy().__as_proxied__()
108+
responses: resources.Responses = ResponsesProxy().__as_proxied__()
109+
embeddings: resources.Embeddings = EmbeddingsProxy().__as_proxied__()
110+
completions: resources.Completions = CompletionsProxy().__as_proxied__()
111+
moderations: resources.Moderations = ModerationsProxy().__as_proxied__()
112+
fine_tuning: resources.FineTuning = FineTuningProxy().__as_proxied__()
113+
vector_stores: resources.VectorStores = VectorStoresProxy().__as_proxied__()

Diff for: tests/test_module_client.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
import httpx
6+
import pytest
7+
from httpx import URL
8+
9+
import openai
10+
from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
11+
12+
13+
def reset_state() -> None:
14+
openai._reset_client()
15+
openai.api_key = None or "My API Key"
16+
openai.organization = None
17+
openai.project = None
18+
openai.base_url = None
19+
openai.timeout = DEFAULT_TIMEOUT
20+
openai.max_retries = DEFAULT_MAX_RETRIES
21+
openai.default_headers = None
22+
openai.default_query = None
23+
openai.http_client = None
24+
25+
26+
@pytest.fixture(autouse=True)
27+
def reset_state_fixture() -> None:
28+
reset_state()
29+
30+
31+
def test_base_url_option() -> None:
32+
assert openai.base_url is None
33+
assert openai.completions._client.base_url == URL("https://api.openai.com/v1/")
34+
35+
openai.base_url = "http://foo.com"
36+
37+
assert openai.base_url == URL("http://foo.com")
38+
assert openai.completions._client.base_url == URL("http://foo.com")
39+
40+
41+
def test_timeout_option() -> None:
42+
assert openai.timeout == openai.DEFAULT_TIMEOUT
43+
assert openai.completions._client.timeout == openai.DEFAULT_TIMEOUT
44+
45+
openai.timeout = 3
46+
47+
assert openai.timeout == 3
48+
assert openai.completions._client.timeout == 3
49+
50+
51+
def test_max_retries_option() -> None:
52+
assert openai.max_retries == openai.DEFAULT_MAX_RETRIES
53+
assert openai.completions._client.max_retries == openai.DEFAULT_MAX_RETRIES
54+
55+
openai.max_retries = 1
56+
57+
assert openai.max_retries == 1
58+
assert openai.completions._client.max_retries == 1
59+
60+
61+
def test_default_headers_option() -> None:
62+
assert openai.default_headers == None
63+
64+
openai.default_headers = {"Foo": "Bar"}
65+
66+
assert openai.default_headers["Foo"] == "Bar"
67+
assert openai.completions._client.default_headers["Foo"] == "Bar"
68+
69+
70+
def test_default_query_option() -> None:
71+
assert openai.default_query is None
72+
assert openai.completions._client._custom_query == {}
73+
74+
openai.default_query = {"Foo": {"nested": 1}}
75+
76+
assert openai.default_query["Foo"] == {"nested": 1}
77+
assert openai.completions._client._custom_query["Foo"] == {"nested": 1}
78+
79+
80+
def test_http_client_option() -> None:
81+
assert openai.http_client is None
82+
83+
original_http_client = openai.completions._client._client
84+
assert original_http_client is not None
85+
86+
new_client = httpx.Client()
87+
openai.http_client = new_client
88+
89+
assert openai.completions._client._client is new_client

0 commit comments

Comments
 (0)