Skip to content

Commit 746174f

Browse files
chore(internal): slight transform perf improvement (#2284)
1 parent 692fd08 commit 746174f

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

Diff for: src/openai/_utils/_transform.py

+22
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def _maybe_transform_key(key: str, type_: type) -> str:
142142
return key
143143

144144

145+
def _no_transform_needed(annotation: type) -> bool:
146+
return annotation == float or annotation == int
147+
148+
145149
def _transform_recursive(
146150
data: object,
147151
*,
@@ -184,6 +188,15 @@ def _transform_recursive(
184188
return cast(object, data)
185189

186190
inner_type = extract_type_arg(stripped_type, 0)
191+
if _no_transform_needed(inner_type):
192+
# for some types there is no need to transform anything, so we can get a small
193+
# perf boost from skipping that work.
194+
#
195+
# but we still need to convert to a list to ensure the data is json-serializable
196+
if is_list(data):
197+
return data
198+
return list(data)
199+
187200
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
188201

189202
if is_union_type(stripped_type):
@@ -332,6 +345,15 @@ async def _async_transform_recursive(
332345
return cast(object, data)
333346

334347
inner_type = extract_type_arg(stripped_type, 0)
348+
if _no_transform_needed(inner_type):
349+
# for some types there is no need to transform anything, so we can get a small
350+
# perf boost from skipping that work.
351+
#
352+
# but we still need to convert to a list to ensure the data is json-serializable
353+
if is_list(data):
354+
return data
355+
return list(data)
356+
335357
return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
336358

337359
if is_union_type(stripped_type):

Diff for: tests/test_transform.py

+12
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,15 @@ async def test_base64_file_input(use_async: bool) -> None:
432432
assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
433433
"foo": "SGVsbG8sIHdvcmxkIQ=="
434434
} # type: ignore[comparison-overlap]
435+
436+
437+
@parametrize
438+
@pytest.mark.asyncio
439+
async def test_transform_skipping(use_async: bool) -> None:
440+
# lists of ints are left as-is
441+
data = [1, 2, 3]
442+
assert await transform(data, List[int], use_async) is data
443+
444+
# iterables of ints are converted to a list
445+
data = iter([1, 2, 3])
446+
assert await transform(data, Iterable[int], use_async) == [1, 2, 3]

0 commit comments

Comments
 (0)