Skip to content

Commit 96e7bec

Browse files
committed
Adjust tests for pydantic2
1 parent 05ce6ca commit 96e7bec

File tree

7 files changed

+184
-37
lines changed

7 files changed

+184
-37
lines changed

tests/integration_tests/mock_llm_outputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def _invoke_llm(
7878
python_rail.COMPILED_PROMPT_1_WITHOUT_INSTRUCTIONS,
7979
python_rail.COMPILED_INSTRUCTIONS,
8080
): python_rail.LLM_OUTPUT_1_FAIL_GUARDRAILS_VALIDATION,
81+
(
82+
python_rail.COMPILED_PROMPT_1_PYDANTIC_2_WITHOUT_INSTRUCTIONS,
83+
python_rail.COMPILED_INSTRUCTIONS,
84+
): python_rail.LLM_OUTPUT_1_FAIL_GUARDRAILS_VALIDATION,
8185
(
8286
python_rail.COMPILED_PROMPT_2_WITHOUT_INSTRUCTIONS,
8387
python_rail.COMPILED_INSTRUCTIONS,

tests/integration_tests/test_assets/pydantic/validated_response_reask.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pydantic import BaseModel, Field
55

6+
from guardrails.utils.pydantic_utils import PYDANTIC_VERSION
67
from guardrails.utils.reask_utils import FieldReAsk
78
from guardrails.validators import (
89
FailResult,
@@ -59,14 +60,31 @@ class Person(BaseModel):
5960
"""
6061

6162
name: str
62-
age: int = Field(..., validators=[AgeMustBeBetween0And150(on_fail="reask")])
63-
zip_code: str = Field(
64-
...,
65-
validators=[
66-
ZipCodeMustBeNumeric(on_fail="reask"),
67-
ZipCodeInCalifornia(on_fail="reask"),
68-
],
69-
)
63+
if PYDANTIC_VERSION.startswith("1"):
64+
age: int = Field(..., validators=[AgeMustBeBetween0And150(on_fail="reask")])
65+
zip_code: str = Field(
66+
...,
67+
validators=[
68+
ZipCodeMustBeNumeric(on_fail="reask"),
69+
ZipCodeInCalifornia(on_fail="reask"),
70+
],
71+
)
72+
else:
73+
age: int = Field(
74+
...,
75+
json_schema_extra={
76+
"validators": [AgeMustBeBetween0And150(on_fail="reask")]
77+
},
78+
)
79+
zip_code: str = Field(
80+
...,
81+
json_schema_extra={
82+
"validators": [
83+
ZipCodeMustBeNumeric(on_fail="reask"),
84+
ZipCodeInCalifornia(on_fail="reask"),
85+
],
86+
},
87+
)
7088

7189

7290
class ListOfPeople(BaseModel):

tests/integration_tests/test_assets/python_rail/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
)
88

99
COMPILED_PROMPT_1_WITHOUT_INSTRUCTIONS = reader("compiled_prompt_1.txt")
10+
COMPILED_PROMPT_1_PYDANTIC_2_WITHOUT_INSTRUCTIONS = reader(
11+
"compiled_prompt_1_pydantic_2.txt"
12+
)
1013
COMPILED_PROMPT_2_WITHOUT_INSTRUCTIONS = reader("compiled_prompt_2.txt")
1114
COMPILED_INSTRUCTIONS = reader("compiled_instructions.txt")
1215

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
Provide detailed information about the top 5 grossing movies from Christopher Nolan including release date, duration, budget, whether it's a sequel, website, and contact email.
2+
3+
Given below is XML that describes the information to extract from this document and the tags to extract it into.
4+
5+
<output>
6+
<string name="name" format="is-valid-director"/>
7+
<list name="movies">
8+
<object>
9+
<integer name="rank"/>
10+
<string name="title"/>
11+
<object name="details">
12+
<date name="release_date"/>
13+
<time name="duration"/>
14+
<float name="budget"/>
15+
<bool name="is_sequel" required="false"/>
16+
<string name="website" format="length: min=9 max=100"/>
17+
<string name="contact_email"/>
18+
<choice name="revenue" discriminator="revenue_type">
19+
<case name="box_office">
20+
<float name="gross"/>
21+
<float name="opening_weekend"/>
22+
</case>
23+
<case name="streaming">
24+
<integer name="subscriptions"/>
25+
<float name="subscription_fee"/>
26+
</case>
27+
</choice>
28+
</object>
29+
</object>
30+
</list>
31+
</output>
32+
33+
34+
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.

tests/integration_tests/test_python_rail.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import openai
66
import pytest
7-
from pydantic import BaseModel, Field, root_validator, validator
7+
from pydantic import BaseModel, Field
88

99
import guardrails as gd
10-
from guardrails.utils.pydantic_utils import add_validator
10+
from guardrails.utils.pydantic_utils import PYDANTIC_VERSION, add_validator
1111
from guardrails.validators import (
1212
FailResult,
1313
PassResult,
@@ -52,7 +52,16 @@ class BoxOfficeRevenue(BaseModel):
5252
opening_weekend: float
5353

5454
# Field-level validation using Pydantic (not Guardrails)
55-
@validator("gross")
55+
if PYDANTIC_VERSION.startswith("1"):
56+
from pydantic import validator
57+
58+
decorator = validator("gross")
59+
else:
60+
from pydantic import field_validator
61+
62+
decorator = field_validator("gross")
63+
64+
@decorator
5665
def validate_gross(cls, gross):
5766
if gross <= 0:
5867
raise ValueError("Gross revenue must be a positive value")
@@ -68,23 +77,47 @@ class Details(BaseModel):
6877
duration: time
6978
budget: float
7079
is_sequel: bool = Field(default=False)
71-
website: str = Field(validators=[ValidLength(min=9, max=100, on_fail="reask")])
80+
81+
# Root-level validation using Pydantic (Not in Guardrails)
82+
if PYDANTIC_VERSION.startswith("1"):
83+
website: str = Field(
84+
validators=[ValidLength(min=9, max=100, on_fail="reask")]
85+
)
86+
from pydantic import root_validator
87+
88+
@root_validator
89+
def validate_budget_and_gross(cls, values):
90+
budget = values.get("budget")
91+
revenue = values.get("revenue")
92+
if isinstance(revenue, BoxOfficeRevenue):
93+
gross = revenue.gross
94+
if budget >= gross:
95+
raise ValueError("Budget must be less than gross revenue")
96+
return values
97+
98+
else:
99+
website: str = Field(
100+
json_schema_extra={
101+
"validators": [ValidLength(min=9, max=100, on_fail="reask")]
102+
}
103+
)
104+
from pydantic import model_validator
105+
106+
@model_validator(mode="before")
107+
def validate_budget_and_gross(cls, values):
108+
budget = values.get("budget")
109+
revenue = values.get("revenue")
110+
if revenue["revenue_type"] == "box_office":
111+
gross = revenue["gross"]
112+
if budget >= gross:
113+
raise ValueError("Budget must be less than gross revenue")
114+
return values
115+
72116
contact_email: str
73117
revenue: Union[BoxOfficeRevenue, StreamingRevenue] = Field(
74118
..., discriminator="revenue_type"
75119
)
76120

77-
# Root-level validation using Pydantic (Not in Guardrails)
78-
@root_validator
79-
def validate_budget_and_gross(cls, values):
80-
budget = values.get("budget")
81-
revenue = values.get("revenue")
82-
if isinstance(revenue, BoxOfficeRevenue):
83-
gross = revenue.gross
84-
if budget >= gross:
85-
raise ValueError("Budget must be less than gross revenue")
86-
return values
87-
88121
class Movie(BaseModel):
89122
rank: int
90123
title: str
@@ -125,9 +158,15 @@ class Director(BaseModel):
125158
# Check that the guard state object has the correct number of re-asks.
126159
assert len(guard_history) == 2
127160

128-
assert guard_history[0].prompt == gd.Prompt(
129-
python_rail.COMPILED_PROMPT_1_WITHOUT_INSTRUCTIONS
130-
)
161+
if PYDANTIC_VERSION.startswith("1"):
162+
assert guard_history[0].prompt == gd.Prompt(
163+
python_rail.COMPILED_PROMPT_1_WITHOUT_INSTRUCTIONS
164+
)
165+
else:
166+
assert guard_history[0].prompt == gd.Prompt(
167+
python_rail.COMPILED_PROMPT_1_PYDANTIC_2_WITHOUT_INSTRUCTIONS
168+
)
169+
131170
assert (
132171
guard_history[0].output == python_rail.LLM_OUTPUT_1_FAIL_GUARDRAILS_VALIDATION
133172
)
@@ -140,20 +179,32 @@ class Director(BaseModel):
140179
== python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION
141180
)
142181

143-
with pytest.raises(ValueError):
144-
Director.parse_raw(
145-
python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION
146-
)
182+
if PYDANTIC_VERSION.startswith("1"):
183+
with pytest.raises(ValueError):
184+
Director.parse_raw(
185+
python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION
186+
)
147187

148-
# The user can take corrective action based on the failed validation.
149-
# Either manipulating the output themselves, taking corrective action
150-
# in their application, or upstreaming their validations into Guardrails.
188+
# The user can take corrective action based on the failed validation.
189+
# Either manipulating the output themselves, taking corrective action
190+
# in their application, or upstreaming their validations into Guardrails.
151191

152-
# The fixed output should pass validation using Pydantic
153-
Director.parse_raw(python_rail.LLM_OUTPUT_3_SUCCEED_GUARDRAILS_AND_PYDANTIC)
192+
# The fixed output should pass validation using Pydantic
193+
Director.parse_raw(python_rail.LLM_OUTPUT_3_SUCCEED_GUARDRAILS_AND_PYDANTIC)
194+
else:
195+
with pytest.raises(ValueError):
196+
Director.model_validate_json(
197+
python_rail.LLM_OUTPUT_2_SUCCEED_GUARDRAILS_BUT_FAIL_PYDANTIC_VALIDATION
198+
)
199+
Director.model_validate_json(
200+
python_rail.LLM_OUTPUT_3_SUCCEED_GUARDRAILS_AND_PYDANTIC
201+
)
154202

155203

204+
@pytest.mark.skipif(not PYDANTIC_VERSION.startswith("1"), reason="Pydantic 1.x only")
156205
def test_python_rail_add_validator(mocker):
206+
from pydantic import root_validator, validator
207+
157208
mocker.patch(
158209
"guardrails.llm_providers.OpenAIChatCallable",
159210
new=MockOpenAIChatCallable,

tests/unit_tests/test_llm_providers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ async def return_async():
283283

284284
class ReturnTempCallable(Callable):
285285
def __call__(*args, **kwargs) -> Any:
286-
return kwargs.get("temperature")
286+
return ""
287287

288288

289289
@pytest.mark.parametrize(
@@ -295,7 +295,8 @@ def __call__(*args, **kwargs) -> Any:
295295
)
296296
def test_get_llm_ask_temperature(llm_api, args, kwargs, expected_temperature):
297297
result = get_llm_ask(llm_api, *args, **kwargs)
298-
assert result().output == str(expected_temperature)
298+
assert "temperature" in result.init_kwargs
299+
assert result.init_kwargs["temperature"] == expected_temperature
299300

300301

301302
def test_chat_prompt():

tests/unit_tests/utils/test_pydantic_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import pytest
12
from pydantic import BaseModel, Field
23

34
from guardrails.utils.pydantic_utils import (
5+
PYDANTIC_VERSION,
46
add_pydantic_validators_as_guardrails_validators,
57
add_validator,
68
)
79
from guardrails.validators import FailResult, PassResult, ValidChoices, ValidLength
810

911

12+
@pytest.mark.skipif(
13+
not PYDANTIC_VERSION.startswith("1"),
14+
reason="Tests validators syntax for Pydantic v1",
15+
)
1016
def test_add_pydantic_validators_as_guardrails_validators():
1117
# TODO(shreya): Uncomment when custom validators are supported
1218
# def dummy_validator(name: str):
@@ -75,3 +81,33 @@ class DummyModel(BaseModel):
7581
# validators[3].validate(None, "Bob", None)
7682
# with pytest.raises(EventDetail):
7783
# validators[3].validate(None, "Alex", None)
84+
85+
86+
@pytest.mark.skipif(
87+
not PYDANTIC_VERSION.startswith("2"),
88+
reason="Tests validators syntax for Pydantic v2",
89+
)
90+
def test_add_pydantic_validators_as_guardrails_validators_v2():
91+
class DummyModel(BaseModel):
92+
name: str = Field(..., validators=[ValidLength(min=1, max=10)])
93+
94+
model_fields = add_pydantic_validators_as_guardrails_validators(DummyModel)
95+
name_field = model_fields["name"]
96+
97+
# Should have 1 field
98+
assert len(model_fields) == 1, "Should only have one field"
99+
100+
# Should have 4 validators: 1 from the field, 2 from the add_validator method,
101+
# and 1 from the validator decorator.
102+
validators = name_field.json_schema_extra["validators"]
103+
assert len(validators) == 1, "Should have 1 validator"
104+
105+
# The BaseModel field should not be modified
106+
assert len(DummyModel.model_fields["name"].json_schema_extra["validators"]) == 1
107+
108+
# The first validator should be the ValidLength validator
109+
assert isinstance(
110+
validators[0], ValidLength
111+
), "First validator should be ValidLength"
112+
assert isinstance(validators[0].validate("Beatrice", {}), PassResult)
113+
assert isinstance(validators[0].validate("MrAlexander", {}), FailResult)

0 commit comments

Comments
 (0)