|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from enum import Enum |
| 4 | + |
| 5 | +from pydantic import Field, BaseModel |
3 | 6 | from inline_snapshot import snapshot
|
4 | 7 |
|
5 | 8 | import openai
|
@@ -161,3 +164,63 @@ def test_most_types() -> None:
|
161 | 164 | },
|
162 | 165 | }
|
163 | 166 | )
|
| 167 | + |
| 168 | + |
| 169 | +class Color(Enum): |
| 170 | + RED = "red" |
| 171 | + BLUE = "blue" |
| 172 | + GREEN = "green" |
| 173 | + |
| 174 | + |
| 175 | +class ColorDetection(BaseModel): |
| 176 | + color: Color = Field(description="The detected color") |
| 177 | + hex_color_code: str = Field(description="The hex color code of the detected color") |
| 178 | + |
| 179 | + |
| 180 | +def test_enums() -> None: |
| 181 | + if PYDANTIC_V2: |
| 182 | + assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot( |
| 183 | + { |
| 184 | + "name": "ColorDetection", |
| 185 | + "strict": True, |
| 186 | + "parameters": { |
| 187 | + "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}}, |
| 188 | + "properties": { |
| 189 | + "color": {"description": "The detected color", "$ref": "#/$defs/Color"}, |
| 190 | + "hex_color_code": { |
| 191 | + "description": "The hex color code of the detected color", |
| 192 | + "title": "Hex Color Code", |
| 193 | + "type": "string", |
| 194 | + }, |
| 195 | + }, |
| 196 | + "required": ["color", "hex_color_code"], |
| 197 | + "title": "ColorDetection", |
| 198 | + "type": "object", |
| 199 | + "additionalProperties": False, |
| 200 | + }, |
| 201 | + } |
| 202 | + ) |
| 203 | + else: |
| 204 | + assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot( |
| 205 | + { |
| 206 | + "name": "ColorDetection", |
| 207 | + "strict": True, |
| 208 | + "parameters": { |
| 209 | + "properties": { |
| 210 | + "color": {"description": "The detected color", "$ref": "#/definitions/Color"}, |
| 211 | + "hex_color_code": { |
| 212 | + "description": "The hex color code of the detected color", |
| 213 | + "title": "Hex Color Code", |
| 214 | + "type": "string", |
| 215 | + }, |
| 216 | + }, |
| 217 | + "required": ["color", "hex_color_code"], |
| 218 | + "title": "ColorDetection", |
| 219 | + "definitions": { |
| 220 | + "Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]} |
| 221 | + }, |
| 222 | + "type": "object", |
| 223 | + "additionalProperties": False, |
| 224 | + }, |
| 225 | + } |
| 226 | + ) |
0 commit comments