Skip to content

Commit 53d964d

Browse files
committed
fix(json schema): unwrap allOfs with one entry
1 parent 1a388a1 commit 53d964d

File tree

2 files changed

+70
-3
lines changed

2 files changed

+70
-3
lines changed

Diff for: src/openai/lib/_pydantic.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,13 @@ def _ensure_strict_json_schema(
5353
# intersections
5454
all_of = json_schema.get("allOf")
5555
if is_list(all_of):
56-
json_schema["allOf"] = [
57-
_ensure_strict_json_schema(entry, path=(*path, "anyOf", str(i))) for i, entry in enumerate(all_of)
58-
]
56+
if len(all_of) == 1:
57+
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
58+
json_schema.pop("allOf")
59+
else:
60+
json_schema["allOf"] = [
61+
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
62+
]
5963

6064
defs = json_schema.get("$defs")
6165
if is_dict(defs):

Diff for: tests/lib/test_pydantic.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from enum import Enum
4+
5+
from pydantic import Field, BaseModel
36
from inline_snapshot import snapshot
47

58
import openai
@@ -161,3 +164,63 @@ def test_most_types() -> None:
161164
},
162165
}
163166
)
167+
168+
169+
class Color(Enum):
170+
RED = "red"
171+
BLUE = "blue"
172+
GREEN = "green"
173+
174+
175+
class ColorDetection(BaseModel):
176+
color: Color = Field(description="The detected color")
177+
hex_color_code: str = Field(description="The hex color code of the detected color")
178+
179+
180+
def test_enums() -> None:
181+
if PYDANTIC_V2:
182+
assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
183+
{
184+
"name": "ColorDetection",
185+
"strict": True,
186+
"parameters": {
187+
"$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
188+
"properties": {
189+
"color": {"description": "The detected color", "$ref": "#/$defs/Color"},
190+
"hex_color_code": {
191+
"description": "The hex color code of the detected color",
192+
"title": "Hex Color Code",
193+
"type": "string",
194+
},
195+
},
196+
"required": ["color", "hex_color_code"],
197+
"title": "ColorDetection",
198+
"type": "object",
199+
"additionalProperties": False,
200+
},
201+
}
202+
)
203+
else:
204+
assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
205+
{
206+
"name": "ColorDetection",
207+
"strict": True,
208+
"parameters": {
209+
"properties": {
210+
"color": {"description": "The detected color", "$ref": "#/definitions/Color"},
211+
"hex_color_code": {
212+
"description": "The hex color code of the detected color",
213+
"title": "Hex Color Code",
214+
"type": "string",
215+
},
216+
},
217+
"required": ["color", "hex_color_code"],
218+
"title": "ColorDetection",
219+
"definitions": {
220+
"Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]}
221+
},
222+
"type": "object",
223+
"additionalProperties": False,
224+
},
225+
}
226+
)

0 commit comments

Comments
 (0)