Skip to content

Commit 9a4f436

Browse files
authored
Merge pull request #674 from aaravnavani/on_fail_enum
Make on_fail an enum
2 parents 941d8ce + b0f64d6 commit 9a4f436

23 files changed

+301
-171
lines changed

docs/hub/api_reference_markdown/validators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ Validates whether the generated code snippet contains any secrets.
165165
```py
166166

167167
guard = Guard.from_string(validators=[
168-
DetectSecrets(on_fail="fix")
168+
DetectSecrets(on_fail=OnFailAction.FIX)
169169
])
170170
guard.parse(
171171
llm_output=code_snippet,

docs/llm_api_wrappers.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ guard = Guard.from_string(
134134
validators=[
135135
ValidLength(
136136
min=48,
137-
on_fail="fix"
137+
on_fail=OnFailAction.FIX
138138
),
139139
ToxicLanguage(
140-
on_fail="fix"
140+
on_fail=OnFailAction.FIX
141141
)
142142
],
143143
prompt=prompt
@@ -179,10 +179,10 @@ guard = Guard.from_string(
179179
validators=[
180180
ValidLength(
181181
min=48,
182-
on_fail="fix"
182+
on_fail=OnFailAction.FIX
183183
),
184184
ToxicLanguage(
185-
on_fail="fix"
185+
on_fail=OnFailAction.FIX
186186
)
187187
],
188188
prompt=prompt

guardrails/utils/pydantic_utils/v1.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from guardrails.datatypes import String as StringDataType
3434
from guardrails.datatypes import Time as TimeDataType
3535
from guardrails.utils.safe_get import safe_get
36-
from guardrails.validator_base import Validator
36+
from guardrails.validator_base import OnFailAction, Validator
3737
from guardrails.validatorsattr import ValidatorsAttr
3838

3939

@@ -218,7 +218,9 @@ def process_validators(vals, fld):
218218
)
219219
if "validators" not in fld.field_info.extra:
220220
fld.field_info.extra["validators"] = []
221-
fld.field_info.extra["validators"].append((gd_validator, "reask"))
221+
fld.field_info.extra["validators"].append(
222+
(gd_validator, OnFailAction.REASK)
223+
)
222224

223225
model_fields = {}
224226
for field_name, field in model.__fields__.items():

guardrails/utils/pydantic_utils/v2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from guardrails.datatypes import String as StringDataType
3535
from guardrails.datatypes import Time as TimeDataType
3636
from guardrails.utils.safe_get import safe_get
37-
from guardrails.validator_base import Validator
37+
from guardrails.validator_base import OnFailAction, Validator
3838
from guardrails.validatorsattr import ValidatorsAttr
3939

4040
DataTypeT = TypeVar("DataTypeT", bound=DataType)
@@ -248,7 +248,9 @@ def process_validators(vals, fld):
248248
)
249249
if "validators" not in fld.field_info.json_schema_extra:
250250
fld.json_schema_extra["validators"] = []
251-
fld.json_schema_extra["validators"].append((gd_validator, "reask"))
251+
fld.json_schema_extra["validators"].append(
252+
(gd_validator, OnFailAction.REASK)
253+
)
252254

253255
model_fields = {}
254256
for field_name, field in model.model_fields.items():

guardrails/validator_base.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
from collections import defaultdict
33
from copy import deepcopy
4+
from enum import Enum
45
from string import Template
56
from typing import (
67
Any,
@@ -373,6 +374,16 @@ class FailResult(ValidationResult):
373374
fix_value: Optional[Any] = None
374375

375376

377+
class OnFailAction(str, Enum):
378+
REASK = "reask"
379+
FIX = "fix"
380+
FILTER = "filter"
381+
REFRAIN = "refrain"
382+
NOOP = "noop"
383+
EXCEPTION = "exception"
384+
FIX_REASK = "fix_reask"
385+
386+
376387
@dataclass # type: ignore
377388
class Validator(Runnable):
378389
"""Base class for validators."""
@@ -384,7 +395,9 @@ class Validator(Runnable):
384395
required_metadata_keys = []
385396
_metadata = {}
386397

387-
def __init__(self, on_fail: Optional[Union[Callable, str]] = None, **kwargs):
398+
def __init__(
399+
self, on_fail: Optional[Union[Callable, OnFailAction]] = None, **kwargs
400+
):
388401
# Raise a warning for deprecated validators
389402

390403
# Get class name and rail_alias
@@ -411,8 +424,8 @@ def __init__(self, on_fail: Optional[Union[Callable, str]] = None, **kwargs):
411424
)
412425

413426
if on_fail is None:
414-
on_fail = "noop"
415-
if isinstance(on_fail, str):
427+
on_fail = OnFailAction.NOOP
428+
if isinstance(on_fail, OnFailAction):
416429
self.on_fail_descriptor = on_fail
417430
self.on_fail_method = None
418431
else:

guardrails/validator_service.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from concurrent.futures import ProcessPoolExecutor
55
from datetime import datetime
6-
from typing import Any, Dict, List, Optional, Tuple
6+
from typing import Any, Dict, List, Optional, Tuple, Union
77

88
from guardrails.classes.history import Iteration
99
from guardrails.datatypes import FieldValidation
@@ -17,6 +17,7 @@
1717
from guardrails.validator_base import (
1818
FailResult,
1919
Filter,
20+
OnFailAction,
2021
PassResult,
2122
Refrain,
2223
ValidationResult,
@@ -59,13 +60,13 @@ def perform_correction(
5960
results: List[FailResult],
6061
value: Any,
6162
validator: Validator,
62-
on_fail_descriptor: str,
63+
on_fail_descriptor: Union[OnFailAction, str],
6364
):
64-
if on_fail_descriptor == "fix":
65+
if on_fail_descriptor == OnFailAction.FIX:
6566
# FIXME: Should we still return fix_value if it is None?
6667
# I think we should warn and return the original value.
6768
return results[0].fix_value
68-
elif on_fail_descriptor == "fix_reask":
69+
elif on_fail_descriptor == OnFailAction.FIX_REASK:
6970
# FIXME: Same thing here
7071
fixed_value = results[0].fix_value
7172
result = self.execute_validator(
@@ -83,21 +84,21 @@ def perform_correction(
8384
if validator.on_fail_method is None:
8485
raise ValueError("on_fail is 'custom' but on_fail_method is None")
8586
return validator.on_fail_method(value, results)
86-
if on_fail_descriptor == "reask":
87+
if on_fail_descriptor == OnFailAction.REASK:
8788
return FieldReAsk(
8889
incorrect_value=value,
8990
fail_results=results,
9091
)
91-
if on_fail_descriptor == "exception":
92+
if on_fail_descriptor == OnFailAction.EXCEPTION:
9293
raise ValidationError(
9394
"Validation failed for field with errors: "
9495
+ ", ".join([result.error_message for result in results])
9596
)
96-
if on_fail_descriptor == "filter":
97+
if on_fail_descriptor == OnFailAction.FILTER:
9798
return Filter()
98-
if on_fail_descriptor == "refrain":
99+
if on_fail_descriptor == OnFailAction.REFRAIN:
99100
return Refrain()
100-
if on_fail_descriptor == "noop":
101+
if on_fail_descriptor == OnFailAction.NOOP:
101102
return value
102103
else:
103104
raise ValueError(
@@ -251,7 +252,11 @@ def group_validators(self, validators):
251252
validators, key=lambda v: (v.on_fail_descriptor, v.override_value_on_pass)
252253
)
253254
for (on_fail_descriptor, override_on_pass), group in groups:
254-
if override_on_pass or on_fail_descriptor in ["fix", "fix_reask", "custom"]:
255+
if override_on_pass or on_fail_descriptor in [
256+
OnFailAction.FIX,
257+
OnFailAction.FIX_REASK,
258+
"custom",
259+
]:
255260
for validator in group:
256261
yield on_fail_descriptor, [validator]
257262
else:

guardrails/validators/detect_secrets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class DetectSecrets(Validator):
4747
```py
4848
4949
guard = Guard.from_string(validators=[
50-
DetectSecrets(on_fail="fix")
50+
DetectSecrets(on_fail=OnFailAction.FIX)
5151
])
5252
guard.parse(
5353
llm_output=code_snippet,

guardrails/validators/ends_with.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from guardrails.logger import logger
44
from guardrails.validator_base import (
55
FailResult,
6+
OnFailAction,
67
PassResult,
78
ValidationResult,
89
Validator,
@@ -26,7 +27,7 @@ class EndsWith(Validator):
2627
end: The required last element.
2728
"""
2829

29-
def __init__(self, end: str, on_fail: str = "fix"):
30+
def __init__(self, end: str, on_fail: OnFailAction = OnFailAction.FIX):
3031
super().__init__(
3132
on_fail=on_fail,
3233
end=end,

guardrails/validators/reading_time.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Callable, Dict, Optional
22

33
from guardrails.logger import logger
44
from guardrails.validator_base import (
@@ -28,7 +28,7 @@ class ReadingTime(Validator):
2828
reading_time: The maximum reading time in minutes.
2929
"""
3030

31-
def __init__(self, reading_time: int, on_fail: Optional[str] = None):
31+
def __init__(self, reading_time: int, on_fail: Optional[Callable] = None):
3232
super().__init__(
3333
on_fail=on_fail,
3434
reading_time=reading_time,

guardrails/validatorsattr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from guardrails.constants import hub
99
from guardrails.utils.xml_utils import cast_xml_to_string
10-
from guardrails.validator_base import Validator, ValidatorSpec
10+
from guardrails.validator_base import OnFailAction, Validator, ValidatorSpec
1111

1212

1313
class ValidatorsAttr(pydantic.BaseModel):
@@ -176,7 +176,7 @@ def from_xml(
176176
key = cast_xml_to_string(key)
177177
if key.startswith("on-fail-"):
178178
on_fail_handler_name = key[len("on-fail-") :]
179-
on_fail_handler = value
179+
on_fail_handler = OnFailAction(value)
180180
on_fail_handlers[on_fail_handler_name] = on_fail_handler
181181

182182
validators, unregistered_validators = cls.get_validators(

tests/integration_tests/test_assets/entity_extraction/pydantic_models.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22

33
from pydantic import BaseModel, Field
44

5+
from guardrails.validator_base import OnFailAction
56
from guardrails.validators import LowerCase, OneLine, TwoWords
67

78

89
class FeeDetailsFilter(BaseModel):
9-
index: int = Field(validators=("1-indexed", "noop"))
10+
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
1011
name: str = Field(
11-
validators=[LowerCase(on_fail="filter"), TwoWords(on_fail="filter")]
12+
validators=[
13+
LowerCase(on_fail=OnFailAction.FILTER),
14+
TwoWords(on_fail=OnFailAction.FILTER),
15+
]
1216
)
13-
explanation: str = Field(validators=OneLine(on_fail="filter"))
14-
value: float = Field(validators=("percentage", "noop"))
17+
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.FILTER))
18+
value: float = Field(validators=("percentage", OnFailAction.NOOP))
1519

1620

1721
class ContractDetailsFilter(BaseModel):
@@ -25,10 +29,15 @@ class ContractDetailsFilter(BaseModel):
2529

2630

2731
class FeeDetailsFix(BaseModel):
28-
index: int = Field(validators=("1-indexed", "noop"))
29-
name: str = Field(validators=[LowerCase(on_fail="fix"), TwoWords(on_fail="fix")])
30-
explanation: str = Field(validators=OneLine(on_fail="fix"))
31-
value: float = Field(validators=("percentage", "noop"))
32+
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
33+
name: str = Field(
34+
validators=[
35+
LowerCase(on_fail=OnFailAction.FIX),
36+
TwoWords(on_fail=OnFailAction.FIX),
37+
]
38+
)
39+
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.FIX))
40+
value: float = Field(validators=("percentage", OnFailAction.NOOP))
3241

3342

3443
class ContractDetailsFix(BaseModel):
@@ -42,10 +51,15 @@ class ContractDetailsFix(BaseModel):
4251

4352

4453
class FeeDetailsNoop(BaseModel):
45-
index: int = Field(validators=("1-indexed", "noop"))
46-
name: str = Field(validators=[LowerCase(on_fail="noop"), TwoWords(on_fail="noop")])
47-
explanation: str = Field(validators=OneLine(on_fail="noop"))
48-
value: float = Field(validators=("percentage", "noop"))
54+
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
55+
name: str = Field(
56+
validators=[
57+
LowerCase(on_fail=OnFailAction.NOOP),
58+
TwoWords(on_fail=OnFailAction.NOOP),
59+
]
60+
)
61+
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.NOOP))
62+
value: float = Field(validators=("percentage", OnFailAction.NOOP))
4963

5064

5165
class ContractDetailsNoop(BaseModel):
@@ -59,10 +73,15 @@ class ContractDetailsNoop(BaseModel):
5973

6074

6175
class FeeDetailsReask(BaseModel):
62-
index: int = Field(validators=("1-indexed", "noop"))
63-
name: str = Field(validators=[LowerCase(on_fail="noop"), TwoWords(on_fail="reask")])
64-
explanation: str = Field(validators=OneLine(on_fail="noop"))
65-
value: float = Field(validators=("percentage", "noop"))
76+
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
77+
name: str = Field(
78+
validators=[
79+
LowerCase(on_fail=OnFailAction.NOOP),
80+
TwoWords(on_fail=OnFailAction.REASK),
81+
]
82+
)
83+
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.NOOP))
84+
value: float = Field(validators=("percentage", OnFailAction.NOOP))
6685

6786

6887
class ContractDetailsReask(BaseModel):
@@ -76,12 +95,15 @@ class ContractDetailsReask(BaseModel):
7695

7796

7897
class FeeDetailsRefrain(BaseModel):
79-
index: int = Field(validators=("1-indexed", "noop"))
98+
index: int = Field(validators=("1-indexed", OnFailAction.NOOP))
8099
name: str = Field(
81-
validators=[LowerCase(on_fail="refrain"), TwoWords(on_fail="refrain")]
100+
validators=[
101+
LowerCase(on_fail=OnFailAction.REFRAIN),
102+
TwoWords(on_fail=OnFailAction.REFRAIN),
103+
]
82104
)
83-
explanation: str = Field(validators=OneLine(on_fail="refrain"))
84-
value: float = Field(validators=("percentage", "noop"))
105+
explanation: str = Field(validators=OneLine(on_fail=OnFailAction.REFRAIN))
106+
value: float = Field(validators=("percentage", OnFailAction.NOOP))
85107

86108

87109
class ContractDetailsRefrain(BaseModel):

0 commit comments

Comments
 (0)