Skip to content

Commit 91c5a1a

Browse files
authored
Merge pull request #420 from irgolic:pydantic-v2-impl
Pydantic v2 Implementation
2 parents 5ced4e6 + eafc5d4 commit 91c5a1a

File tree

15 files changed

+776
-70
lines changed

15 files changed

+776
-70
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ jobs:
2020
strategy:
2121
matrix:
2222
python-version: ['3.8', '3.9', '3.10', '3.11']
23-
2423
steps:
2524
- uses: actions/checkout@v2
2625
- name: Set up Python ${{ matrix.python-version }}
@@ -51,7 +50,7 @@ jobs:
5150
strategy:
5251
matrix:
5352
python-version: ['3.8', '3.9', '3.10', '3.11']
54-
53+
pydantic-version: ['1.10.9', '2.4.2']
5554
steps:
5655
- uses: actions/checkout@v2
5756
- name: Set up Python ${{ matrix.python-version }}
@@ -63,7 +62,7 @@ jobs:
6362
uses: actions/cache@v3
6463
with:
6564
path: ~/.cache/pypoetry
66-
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}
65+
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}
6766

6867
- name: Install Poetry
6968
uses: snok/install-poetry@v1
@@ -72,10 +71,17 @@ jobs:
7271
# TODO: fix errors so that we can run `make dev` instead
7372
run: |
7473
make full
74+
poetry run pip install pydantic==${{ matrix.pydantic-version }}
75+
76+
- if: matrix.pydantic-version == '2.4.2'
77+
name: Static analysis with pyright (ignoring pydantic v1)
78+
run: |
79+
make type-pydantic-v2
7580
76-
- name: Static analysis with pyright
81+
- if: matrix.pydantic-version == '1.10.9'
82+
name: Static analysis with mypy (ignoring pydantic v2)
7783
run: |
78-
make type
84+
make type-pydantic-v1
7985
8086
Pytests:
8187
runs-on: ubuntu-latest
@@ -85,7 +91,7 @@ jobs:
8591
# TODO: fix errors so that we can run both `make dev` and `make full`
8692
# dependencies: ['dev', 'full']
8793
dependencies: ['full']
88-
94+
pydantic-version: ['1.10.9', '2.4.2']
8995
steps:
9096
- uses: actions/checkout@v2
9197
- name: Set up Python ${{ matrix.python-version }}
@@ -97,15 +103,15 @@ jobs:
97103
uses: actions/cache@v3
98104
with:
99105
path: ~/.cache/pypoetry
100-
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}
106+
key: poetry-cache-${{ runner.os }}-${{ steps.setup_python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ matrix.pydantic-version }}
101107

102108
- name: Install Poetry
103109
uses: snok/install-poetry@v1
104110

105111
- name: Install Dependencies
106112
run: |
107-
python -m pip install --upgrade pip
108113
make ${{ matrix.dependencies }}
114+
python -m pip install pydantic==${{ matrix.pydantic-version }}
109115
110116
- name: Run Pytests
111117
run: |

Makefile

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@ autoformat:
88
type:
99
poetry run pyright guardrails/
1010

11+
type-pydantic-v1:
12+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v2.py"]}' > pyrightconfig.json
13+
poetry run pyright guardrails/
14+
rm pyrightconfig.json
15+
16+
type-pydantic-v2:
17+
echo '{"exclude": ["guardrails/utils/pydantic_utils/v1.py"]}' > pyrightconfig.json
18+
poetry run pyright guardrails/
19+
rm pyrightconfig.json
20+
1121
lint:
1222
poetry run isort -c guardrails/ tests/
1323
poetry run black guardrails/ tests/ --check
@@ -44,4 +54,4 @@ all: autoformat type lint docs test
4454
precommit:
4555
# pytest -x -q --no-summary
4656
pyright guardrails/
47-
make lint
57+
make lint

guardrails/utils/parsing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_code_block(
6565

6666
def get_template_variables(template: str) -> List[str]:
6767
if hasattr(Template, "get_identifiers"):
68-
return Template(template).get_identifiers()
68+
return Template(template).get_identifiers() # type: ignore
6969
else:
7070
d = collections.defaultdict(str)
7171
Template(template).safe_substitute(d)

guardrails/utils/pydantic_utils/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
convert_pydantic_model_to_openai_fn,
1212
)
1313
else:
14-
raise NotImplementedError(f"Pydantic version {PYDANTIC_VERSION} is not supported.")
14+
from .v2 import (
15+
ArbitraryModel,
16+
add_pydantic_validators_as_guardrails_validators,
17+
add_validator,
18+
convert_pydantic_model_to_datatype,
19+
convert_pydantic_model_to_openai_fn,
20+
)
1521

1622

1723
__all__ = [

guardrails/utils/pydantic_utils/v1.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -311,22 +311,22 @@ def convert_pydantic_model_to_datatype(
311311
inner_type = get_args(type_annotation)
312312
if len(inner_type) == 0:
313313
# If the list is empty, we cannot infer the type of the elements
314-
children[field_name] = datatype_to_pydantic_field(
314+
children[field_name] = pydantic_field_to_datatype(
315315
ListDataType,
316316
field,
317317
strict=strict,
318318
)
319+
continue
319320
inner_type = inner_type[0]
320321
if is_pydantic_base_model(inner_type):
321322
child = convert_pydantic_model_to_datatype(inner_type)
322323
else:
323324
inner_target_datatype = field_to_datatype(inner_type)
324-
child = datatype_to_pydantic_field(
325+
child = construct_datatype(
325326
inner_target_datatype,
326-
inner_type,
327327
strict=strict,
328328
)
329-
children[field_name] = datatype_to_pydantic_field(
329+
children[field_name] = pydantic_field_to_datatype(
330330
ListDataType,
331331
field,
332332
children={"item": child},
@@ -349,7 +349,7 @@ def convert_pydantic_model_to_datatype(
349349
strict=strict,
350350
excluded_fields=[discriminator],
351351
)
352-
children[field_name] = datatype_to_pydantic_field(
352+
children[field_name] = pydantic_field_to_datatype(
353353
Choice,
354354
field,
355355
children=choice_children,
@@ -361,31 +361,28 @@ def convert_pydantic_model_to_datatype(
361361
field, datatype=target_datatype, strict=strict
362362
)
363363
else:
364-
children[field_name] = datatype_to_pydantic_field(
364+
children[field_name] = pydantic_field_to_datatype(
365365
target_datatype,
366366
field,
367367
strict=strict,
368368
)
369369

370370
if isinstance(model_field, ModelField):
371-
return datatype_to_pydantic_field(
371+
return pydantic_field_to_datatype(
372372
datatype,
373373
model_field,
374374
children=children,
375375
strict=strict,
376376
)
377377
else:
378-
format_attr = FormatAttr.from_validators([], ObjectDataType.tag, strict)
379-
return datatype(
378+
return construct_datatype(
379+
datatype,
380380
children=children,
381-
format_attr=format_attr,
382-
optional=False,
383381
name=name,
384-
description=None,
385382
)
386383

387384

388-
def datatype_to_pydantic_field(
385+
def pydantic_field_to_datatype(
389386
datatype: Type[T],
390387
field: ModelField,
391388
children: Optional[Dict[str, "DataType"]] = None,
@@ -396,14 +393,38 @@ def datatype_to_pydantic_field(
396393
children = {}
397394

398395
validators = field.field_info.extra.get("validators", [])
399-
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
400396

401397
is_optional = field.required is False
402398

403399
name = field.name
404400
description = field.field_info.description
405401

406-
data_type = datatype(
407-
children, format_attr, is_optional, name, description, **kwargs
402+
return construct_datatype(
403+
datatype,
404+
children,
405+
validators,
406+
is_optional,
407+
name,
408+
description,
409+
strict=strict,
410+
**kwargs,
408411
)
409-
return data_type
412+
413+
414+
def construct_datatype(
415+
datatype: Type[T],
416+
children: Optional[Dict[str, Any]] = None,
417+
validators: Optional[typing.List[Validator]] = None,
418+
optional: bool = False,
419+
name: Optional[str] = None,
420+
description: Optional[str] = None,
421+
strict: bool = False,
422+
**kwargs,
423+
) -> T:
424+
if children is None:
425+
children = {}
426+
if validators is None:
427+
validators = []
428+
429+
format_attr = FormatAttr.from_validators(validators, datatype.tag, strict)
430+
return datatype(children, format_attr, optional, name, description, **kwargs)

0 commit comments

Comments
 (0)