Skip to content

Commit 7853a83

Browse files
chore(internal): split up transforms into sync / async (#1210)
1 parent 3ab6f44 commit 7853a83

File tree

17 files changed

+363
-111
lines changed

17 files changed

+363
-111
lines changed

src/openai/_utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,7 @@
4444
from ._transform import (
4545
PropertyInfo as PropertyInfo,
4646
transform as transform,
47+
async_transform as async_transform,
4748
maybe_transform as maybe_transform,
49+
async_maybe_transform as async_maybe_transform,
4850
)

src/openai/_utils/_transform.py

+123-5
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,7 @@ def _transform_recursive(
180180
if isinstance(data, pydantic.BaseModel):
181181
return model_dump(data, exclude_unset=True)
182182

183-
return _transform_value(data, annotation)
184-
185-
186-
def _transform_value(data: object, type_: type) -> object:
187-
annotated_type = _get_annotated_type(type_)
183+
annotated_type = _get_annotated_type(annotation)
188184
if annotated_type is None:
189185
return data
190186

@@ -222,3 +218,125 @@ def _transform_typeddict(
222218
else:
223219
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
224220
return result
221+
222+
223+
async def async_maybe_transform(
224+
data: object,
225+
expected_type: object,
226+
) -> Any | None:
227+
"""Wrapper over `async_transform()` that allows `None` to be passed.
228+
229+
See `async_transform()` for more details.
230+
"""
231+
if data is None:
232+
return None
233+
return await async_transform(data, expected_type)
234+
235+
236+
async def async_transform(
237+
data: _T,
238+
expected_type: object,
239+
) -> _T:
240+
"""Transform dictionaries based off of type information from the given type, for example:
241+
242+
```py
243+
class Params(TypedDict, total=False):
244+
card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
245+
246+
247+
transformed = transform({"card_id": "<my card ID>"}, Params)
248+
# {'cardID': '<my card ID>'}
249+
```
250+
251+
Any keys / data that does not have type information given will be included as is.
252+
253+
It should be noted that the transformations that this function does are not represented in the type system.
254+
"""
255+
transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
256+
return cast(_T, transformed)
257+
258+
259+
async def _async_transform_recursive(
260+
data: object,
261+
*,
262+
annotation: type,
263+
inner_type: type | None = None,
264+
) -> object:
265+
"""Transform the given data against the expected type.
266+
267+
Args:
268+
annotation: The direct type annotation given to the particular piece of data.
269+
This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
270+
271+
inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
272+
is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
273+
the list can be transformed using the metadata from the container type.
274+
275+
Defaults to the same value as the `annotation` argument.
276+
"""
277+
if inner_type is None:
278+
inner_type = annotation
279+
280+
stripped_type = strip_annotated_type(inner_type)
281+
if is_typeddict(stripped_type) and is_mapping(data):
282+
return await _async_transform_typeddict(data, stripped_type)
283+
284+
if (
285+
# List[T]
286+
(is_list_type(stripped_type) and is_list(data))
287+
# Iterable[T]
288+
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
289+
):
290+
inner_type = extract_type_arg(stripped_type, 0)
291+
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
292+
293+
if is_union_type(stripped_type):
294+
# For union types we run the transformation against all subtypes to ensure that everything is transformed.
295+
#
296+
# TODO: there may be edge cases where the same normalized field name will transform to two different names
297+
# in different subtypes.
298+
for subtype in get_args(stripped_type):
299+
data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
300+
return data
301+
302+
if isinstance(data, pydantic.BaseModel):
303+
return model_dump(data, exclude_unset=True)
304+
305+
annotated_type = _get_annotated_type(annotation)
306+
if annotated_type is None:
307+
return data
308+
309+
# ignore the first argument as it is the actual type
310+
annotations = get_args(annotated_type)[1:]
311+
for annotation in annotations:
312+
if isinstance(annotation, PropertyInfo) and annotation.format is not None:
313+
return await _async_format_data(data, annotation.format, annotation.format_template)
314+
315+
return data
316+
317+
318+
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
319+
if isinstance(data, (date, datetime)):
320+
if format_ == "iso8601":
321+
return data.isoformat()
322+
323+
if format_ == "custom" and format_template is not None:
324+
return data.strftime(format_template)
325+
326+
return data
327+
328+
329+
async def _async_transform_typeddict(
330+
data: Mapping[str, object],
331+
expected_type: type,
332+
) -> Mapping[str, object]:
333+
result: dict[str, object] = {}
334+
annotations = get_type_hints(expected_type, include_extras=True)
335+
for key, value in data.items():
336+
type_ = annotations.get(key)
337+
if type_ is None:
338+
# we do not have a type annotation for this field, leave it as is
339+
result[key] = value
340+
else:
341+
result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
342+
return result

src/openai/resources/audio/speech.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
from ... import _legacy_response
1111
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
12-
from ..._utils import maybe_transform
12+
from ..._utils import (
13+
maybe_transform,
14+
async_maybe_transform,
15+
)
1316
from ..._compat import cached_property
1417
from ..._resource import SyncAPIResource, AsyncAPIResource
1518
from ..._response import (
@@ -161,7 +164,7 @@ async def create(
161164
extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})}
162165
return await self._post(
163166
"/audio/speech",
164-
body=maybe_transform(
167+
body=await async_maybe_transform(
165168
{
166169
"input": input,
167170
"model": model,

src/openai/resources/audio/transcriptions.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
from ... import _legacy_response
1111
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
12-
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
12+
from ..._utils import (
13+
extract_files,
14+
maybe_transform,
15+
deepcopy_minimal,
16+
async_maybe_transform,
17+
)
1318
from ..._compat import cached_property
1419
from ..._resource import SyncAPIResource, AsyncAPIResource
1520
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -200,7 +205,7 @@ async def create(
200205
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
201206
return await self._post(
202207
"/audio/transcriptions",
203-
body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
208+
body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
204209
files=files,
205210
options=make_request_options(
206211
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout

src/openai/resources/audio/translations.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
from ... import _legacy_response
1111
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
12-
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
12+
from ..._utils import (
13+
extract_files,
14+
maybe_transform,
15+
deepcopy_minimal,
16+
async_maybe_transform,
17+
)
1318
from ..._compat import cached_property
1419
from ..._resource import SyncAPIResource, AsyncAPIResource
1520
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -174,7 +179,7 @@ async def create(
174179
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
175180
return await self._post(
176181
"/audio/translations",
177-
body=maybe_transform(body, translation_create_params.TranslationCreateParams),
182+
body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams),
178183
files=files,
179184
options=make_request_options(
180185
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout

src/openai/resources/beta/assistants/assistants.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
AsyncFilesWithStreamingResponse,
1818
)
1919
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
20-
from ...._utils import maybe_transform
20+
from ...._utils import (
21+
maybe_transform,
22+
async_maybe_transform,
23+
)
2124
from ...._compat import cached_property
2225
from ...._resource import SyncAPIResource, AsyncAPIResource
2326
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -410,7 +413,7 @@ async def create(
410413
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
411414
return await self._post(
412415
"/assistants",
413-
body=maybe_transform(
416+
body=await async_maybe_transform(
414417
{
415418
"model": model,
416419
"description": description,
@@ -525,7 +528,7 @@ async def update(
525528
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
526529
return await self._post(
527530
f"/assistants/{assistant_id}",
528-
body=maybe_transform(
531+
body=await async_maybe_transform(
529532
{
530533
"description": description,
531534
"file_ids": file_ids,

src/openai/resources/beta/assistants/files.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from .... import _legacy_response
1010
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
11-
from ...._utils import maybe_transform
11+
from ...._utils import (
12+
maybe_transform,
13+
async_maybe_transform,
14+
)
1215
from ...._compat import cached_property
1316
from ...._resource import SyncAPIResource, AsyncAPIResource
1417
from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -259,7 +262,7 @@ async def create(
259262
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
260263
return await self._post(
261264
f"/assistants/{assistant_id}/files",
262-
body=maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams),
265+
body=await async_maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams),
263266
options=make_request_options(
264267
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
265268
),

src/openai/resources/beta/threads/messages/messages.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
AsyncFilesWithStreamingResponse,
1818
)
1919
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
20-
from ....._utils import maybe_transform
20+
from ....._utils import (
21+
maybe_transform,
22+
async_maybe_transform,
23+
)
2124
from ....._compat import cached_property
2225
from ....._resource import SyncAPIResource, AsyncAPIResource
2326
from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -315,7 +318,7 @@ async def create(
315318
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
316319
return await self._post(
317320
f"/threads/{thread_id}/messages",
318-
body=maybe_transform(
321+
body=await async_maybe_transform(
319322
{
320323
"content": content,
321324
"role": role,
@@ -404,7 +407,7 @@ async def update(
404407
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
405408
return await self._post(
406409
f"/threads/{thread_id}/messages/{message_id}",
407-
body=maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams),
410+
body=await async_maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams),
408411
options=make_request_options(
409412
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
410413
),

src/openai/resources/beta/threads/runs/runs.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
AsyncStepsWithStreamingResponse,
1818
)
1919
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
20-
from ....._utils import maybe_transform
20+
from ....._utils import (
21+
maybe_transform,
22+
async_maybe_transform,
23+
)
2124
from ....._compat import cached_property
2225
from ....._resource import SyncAPIResource, AsyncAPIResource
2326
from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -430,7 +433,7 @@ async def create(
430433
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
431434
return await self._post(
432435
f"/threads/{thread_id}/runs",
433-
body=maybe_transform(
436+
body=await async_maybe_transform(
434437
{
435438
"assistant_id": assistant_id,
436439
"additional_instructions": additional_instructions,
@@ -521,7 +524,7 @@ async def update(
521524
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
522525
return await self._post(
523526
f"/threads/{thread_id}/runs/{run_id}",
524-
body=maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams),
527+
body=await async_maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams),
525528
options=make_request_options(
526529
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
527530
),
@@ -669,7 +672,7 @@ async def submit_tool_outputs(
669672
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
670673
return await self._post(
671674
f"/threads/{thread_id}/runs/{run_id}/submit_tool_outputs",
672-
body=maybe_transform(
675+
body=await async_maybe_transform(
673676
{"tool_outputs": tool_outputs}, run_submit_tool_outputs_params.RunSubmitToolOutputsParams
674677
),
675678
options=make_request_options(

src/openai/resources/beta/threads/threads.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
AsyncMessagesWithStreamingResponse,
2525
)
2626
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
27-
from ...._utils import maybe_transform
27+
from ...._utils import (
28+
maybe_transform,
29+
async_maybe_transform,
30+
)
2831
from .runs.runs import Runs, AsyncRuns
2932
from ...._compat import cached_property
3033
from ...._resource import SyncAPIResource, AsyncAPIResource
@@ -342,7 +345,7 @@ async def create(
342345
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
343346
return await self._post(
344347
"/threads",
345-
body=maybe_transform(
348+
body=await async_maybe_transform(
346349
{
347350
"messages": messages,
348351
"metadata": metadata,
@@ -423,7 +426,7 @@ async def update(
423426
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
424427
return await self._post(
425428
f"/threads/{thread_id}",
426-
body=maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams),
429+
body=await async_maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams),
427430
options=make_request_options(
428431
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
429432
),
@@ -517,7 +520,7 @@ async def create_and_run(
517520
extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})}
518521
return await self._post(
519522
"/threads/runs",
520-
body=maybe_transform(
523+
body=await async_maybe_transform(
521524
{
522525
"assistant_id": assistant_id,
523526
"instructions": instructions,

src/openai/resources/chat/completions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
from ... import _legacy_response
1111
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
12-
from ..._utils import required_args, maybe_transform
12+
from ..._utils import (
13+
required_args,
14+
maybe_transform,
15+
async_maybe_transform,
16+
)
1317
from ..._compat import cached_property
1418
from ..._resource import SyncAPIResource, AsyncAPIResource
1519
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
@@ -1329,7 +1333,7 @@ async def create(
13291333
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
13301334
return await self._post(
13311335
"/chat/completions",
1332-
body=maybe_transform(
1336+
body=await async_maybe_transform(
13331337
{
13341338
"messages": messages,
13351339
"model": model,

0 commit comments

Comments
 (0)