Skip to content

Commit a4b4ecf

Browse files
committed
pydantic: Allow Dict/List field types (fix #319)
1 parent 96e7bec commit a4b4ecf

File tree

3 files changed

+126
-34
lines changed

3 files changed

+126
-34
lines changed

guardrails/utils/pydantic_utils/v1.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -311,22 +311,22 @@ def convert_pydantic_model_to_datatype(
311311
inner_type = get_args(type_annotation)
312312
if len(inner_type) == 0:
313313
# If the list is empty, we cannot infer the type of the elements
314-
children[field_name] = datatype_to_pydantic_field(
314+
children[field_name] = pydantic_field_to_datatype(
315315
ListDataType,
316316
field,
317317
strict=strict,
318318
)
319+
continue
319320
inner_type = inner_type[0]
320321
if is_pydantic_base_model(inner_type):
321322
child = convert_pydantic_model_to_datatype(inner_type)
322323
else:
323324
inner_target_datatype = field_to_datatype(inner_type)
324-
child = datatype_to_pydantic_field(
325+
child = construct_datatype(
325326
inner_target_datatype,
326-
inner_type,
327327
strict=strict,
328328
)
329-
children[field_name] = datatype_to_pydantic_field(
329+
children[field_name] = pydantic_field_to_datatype(
330330
ListDataType,
331331
field,
332332
children={"item": child},
@@ -349,7 +349,7 @@ def convert_pydantic_model_to_datatype(
349349
strict=strict,
350350
excluded_fields=[discriminator],
351351
)
352-
children[field_name] = datatype_to_pydantic_field(
352+
children[field_name] = pydantic_field_to_datatype(
353353
Choice,
354354
field,
355355
children=choice_children,
@@ -361,31 +361,28 @@ def convert_pydantic_model_to_datatype(
361361
field, datatype=target_datatype, strict=strict
362362
)
363363
else:
364-
children[field_name] = datatype_to_pydantic_field(
364+
children[field_name] = pydantic_field_to_datatype(
365365
target_datatype,
366366
field,
367367
strict=strict,
368368
)
369369

370370
if isinstance(model_field, ModelField):
371-
return datatype_to_pydantic_field(
371+
return pydantic_field_to_datatype(
372372
datatype,
373373
model_field,
374374
children=children,
375375
strict=strict,
376376
)
377377
else:
378-
format_attr = FormatAttr.from_validators([], ObjectDataType.tag, strict)
379-
return datatype(
378+
return construct_datatype(
379+
datatype,
380380
children=children,
381-
format_attr=format_attr,
382-
optional=False,
383381
name=name,
384-
description=None,
385382
)
386383

387384

388-
def datatype_to_pydantic_field(
385+
def pydantic_field_to_datatype(
389386
datatype: Type[T],
390387
field: ModelField,
391388
children: Optional[Dict[str, "DataType"]] = None,
@@ -396,14 +393,38 @@ def datatype_to_pydantic_field(
396393
children = {}
397394

398395
validators = field.field_info.extra.get("validators", [])
399-
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
400396

401397
is_optional = field.required is False
402398

403399
name = field.name
404400
description = field.field_info.description
405401

406-
data_type = datatype(
407-
children, format_attr, is_optional, name, description, **kwargs
402+
return construct_datatype(
403+
datatype,
404+
children,
405+
validators,
406+
is_optional,
407+
name,
408+
description,
409+
strict=strict,
410+
**kwargs,
408411
)
409-
return data_type
412+
413+
414+
def construct_datatype(
415+
datatype: Type[T],
416+
children: Optional[Dict[str, Any]] = None,
417+
validators: Optional[list[Validator]] = None,
418+
optional: bool = False,
419+
name: Optional[str] = None,
420+
description: Optional[str] = None,
421+
strict: bool = False,
422+
**kwargs,
423+
) -> T:
424+
if children is None:
425+
children = {}
426+
if validators is None:
427+
validators = []
428+
429+
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
430+
return datatype(children, format_attr, optional, name, description, **kwargs)

guardrails/utils/pydantic_utils/v2.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -332,25 +332,25 @@ def convert_pydantic_model_to_datatype(
332332
inner_type = get_args(type_annotation)
333333
if len(inner_type) == 0:
334334
# If the list is empty, we cannot infer the type of the elements
335-
children[field_name] = datatype_to_pydantic_field(
335+
children[field_name] = pydantic_field_to_datatype(
336336
ListDataType,
337337
field,
338338
strict=strict,
339339
)
340+
continue
340341
inner_type = inner_type[0]
341342
if is_pydantic_base_model(inner_type):
342343
child = convert_pydantic_model_to_datatype(
343344
inner_type,
344345
)
345346
else:
346347
inner_target_datatype = field_to_datatype(inner_type)
347-
child = datatype_to_pydantic_field(
348+
child = construct_datatype(
348349
inner_target_datatype,
349-
inner_type,
350350
strict=strict,
351351
name=field_name,
352352
)
353-
children[field_name] = datatype_to_pydantic_field(
353+
children[field_name] = pydantic_field_to_datatype(
354354
ListDataType,
355355
field,
356356
children={"item": child},
@@ -374,7 +374,7 @@ def convert_pydantic_model_to_datatype(
374374
strict=strict,
375375
excluded_fields=[discriminator],
376376
)
377-
children[field_name] = datatype_to_pydantic_field(
377+
children[field_name] = pydantic_field_to_datatype(
378378
Choice,
379379
field,
380380
children=choice_children,
@@ -390,33 +390,30 @@ def convert_pydantic_model_to_datatype(
390390
name=field_name,
391391
)
392392
else:
393-
children[field_name] = datatype_to_pydantic_field(
393+
children[field_name] = pydantic_field_to_datatype(
394394
target_datatype,
395395
field,
396396
strict=strict,
397397
name=field_name,
398398
)
399399

400400
if isinstance(model_field, FieldInfo):
401-
return datatype_to_pydantic_field(
401+
return pydantic_field_to_datatype(
402402
datatype,
403403
model_field,
404404
children=children,
405405
strict=strict,
406406
name=name,
407407
)
408408
else:
409-
format_attr = FormatAttr.from_validators([], ObjectDataType.tag, strict)
410-
return datatype(
409+
return construct_datatype(
410+
datatype,
411411
children=children,
412-
format_attr=format_attr,
413-
optional=False,
414412
name=name,
415-
description=None,
416413
)
417414

418415

419-
def datatype_to_pydantic_field(
416+
def pydantic_field_to_datatype(
420417
datatype: Type[DataTypeT],
421418
field: FieldInfo,
422419
children: Optional[Dict[str, "DataType"]] = None,
@@ -431,15 +428,39 @@ def datatype_to_pydantic_field(
431428
validators = []
432429
else:
433430
validators = field.json_schema_extra.get("validators", [])
434-
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
435431

436432
is_optional = not field.is_required()
437433

438434
if field.title is not None:
439435
name = field.title
440436
description = field.description
441437

442-
data_type = datatype(
443-
children, format_attr, is_optional, name, description, **kwargs
438+
return construct_datatype(
439+
datatype,
440+
children,
441+
validators,
442+
is_optional,
443+
name,
444+
description,
445+
strict=strict,
446+
**kwargs,
444447
)
445-
return data_type
448+
449+
450+
def construct_datatype(
451+
datatype: Type[DataTypeT],
452+
children: Optional[Dict[str, Any]] = None,
453+
validators: Optional[list[Validator]] = None,
454+
optional: bool = False,
455+
name: Optional[str] = None,
456+
description: Optional[str] = None,
457+
strict: bool = False,
458+
**kwargs,
459+
) -> DataTypeT:
460+
if children is None:
461+
children = {}
462+
if validators is None:
463+
validators = []
464+
465+
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
466+
return datatype(children, format_attr, optional, name, description, **kwargs)

tests/integration_tests/test_pydantic.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
import json
2+
from typing import Dict, List
3+
14
import openai
5+
import pytest
6+
from pydantic import BaseModel
27

38
import guardrails as gd
49

@@ -91,3 +96,48 @@ def test_pydantic_with_full_schema_reask(mocker):
9196
)
9297
assert guard_history[2].output == pydantic.LLM_OUTPUT_FULL_REASK_2
9398
assert guard_history[2].validated_output == pydantic.VALIDATED_OUTPUT_REASK_3
99+
100+
101+
class ContainerModel(BaseModel):
102+
annotated_dict: Dict[str, str] = {}
103+
annotated_dict_in_list: List[Dict[str, str]] = []
104+
annotated_list: List[str] = []
105+
annotated_list_in_dict: Dict[str, List[str]] = {}
106+
107+
108+
class ContainerModel2(BaseModel):
109+
dict_: Dict = {}
110+
dict_in_list: List[Dict] = []
111+
list_: List = []
112+
list_in_dict: Dict[str, List] = {}
113+
114+
115+
@pytest.mark.parametrize(
116+
"model, output",
117+
[
118+
(
119+
ContainerModel,
120+
{
121+
"annotated_dict": {"a": "b"},
122+
"annotated_dict_in_list": [{"a": "b"}],
123+
"annotated_list": ["a"],
124+
"annotated_list_in_dict": {"a": ["b"]},
125+
},
126+
),
127+
(
128+
ContainerModel2,
129+
{
130+
"dict_": {"a": "b"},
131+
"dict_in_list": [{"a": "b"}],
132+
"list_": ["a"],
133+
"list_in_dict": {"a": ["b"]},
134+
},
135+
),
136+
],
137+
)
138+
def test_container_types(model, output):
139+
output_str = json.dumps(output)
140+
141+
guard = gd.Guard.from_pydantic(model)
142+
out = guard.parse(output_str)
143+
assert out == output

0 commit comments

Comments
 (0)