Skip to content

Commit 2f79ed1

Browse files
committed
test_validators: Add custom on_fail handler test
1 parent dfe1695 commit 2f79ed1

File tree

1 file changed

+102
-1
lines changed

1 file changed

+102
-1
lines changed

tests/unit_tests/test_validators.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# noqa:W291
22
import os
3-
from typing import Any, Dict
3+
from typing import Any, Dict, List
44

55
import openai
66
import pytest
@@ -16,6 +16,7 @@
1616
PassResult,
1717
Refrain,
1818
ValidationResult,
19+
ValidatorError,
1920
check_refrain_in_dict,
2021
filter_in_dict,
2122
register_validator,
@@ -503,3 +504,103 @@ def test_detect_secrets():
503504
# Check if mod_value is same as code_snippet,
504505
# as there are no secrets in code_snippet
505506
assert mod_value == NO_SECRETS_CODE_SNIPPET
507+
508+
509+
def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """Custom ``on_fail`` fix handler: repair the value by doubling it."""
    repaired = value + " " + value
    return repaired
511+
512+
513+
def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """Custom ``on_fail`` re-ask handler: wrap the bad value and its failures."""
    reask = FieldReAsk(incorrect_value=value, fail_results=fail_results)
    return reask
515+
516+
517+
def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """Custom ``on_fail`` handler that aborts validation with an exception."""
    error = ValidatorError("Something went wrong!")
    raise error
519+
520+
521+
def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """Custom ``on_fail`` handler that drops the failing field from the output."""
    outcome = Filter()
    return outcome
523+
524+
525+
def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """Custom ``on_fail`` handler that refrains from returning any output."""
    outcome = Refrain()
    return outcome
527+
528+
529+
# Exercise per-field custom ``on_fail`` hooks: each handler maps a validation
# failure on ``pet_type`` ("dog" fails TwoWords) to a different outcome —
# fixed value, FieldReAsk, raised exception, filtered field, or refrained
# (empty) output. Each handler pairs with its expected final result.
@pytest.mark.parametrize(
    "validator_func, expected_result",
    [
        (
            custom_fix_on_fail_handler,
            # fix handler doubles the failing value in place
            {"pet_type": "dog dog", "name": "Fido"},
        ),
        (
            custom_reask_on_fail_handler,
            # reask handler records a FieldReAsk at the failing path
            FieldReAsk(
                incorrect_value="dog",
                path=["pet_type"],
                fail_results=[
                    FailResult(
                        error_message="must be exactly two words",
                        fix_value="dog",
                    )
                ],
            ),
        ),
        (
            custom_exception_on_fail_handler,
            # an exception type means "expect guard.parse to raise"
            ValidatorError,
        ),
        (
            custom_filter_on_fail_handler,
            # filter handler removes only the failing field
            {"name": "Fido"},
        ),
        (
            custom_refrain_on_fail_handler,
            # refrain handler suppresses the entire output
            {},
        ),
    ],
)
# Run every handler through both ways of attaching a validator:
# a Validator instance and a (name, callable) tuple spec.
@pytest.mark.parametrize(
    "validator_spec",
    [
        lambda val_func: TwoWords(on_fail=val_func),
        lambda val_func: ("two-words", val_func),
    ],
)
def test_custom_on_fail_handler(
    validator_spec,
    validator_func,
    expected_result,
):
    """Verify that each custom on_fail handler produces its expected outcome
    when parsing an LLM output whose ``pet_type`` fails the TwoWords check."""
    prompt = """
What kind of pet should I get and what should I name it?

${gr.complete_json_suffix_v2}
"""

    # Canned LLM output: "dog" is a single word, so TwoWords fails on it.
    output = """
{
"pet_type": "dog",
"name": "Fido"
}
"""

    class Pet(BaseModel):
        # the validator under test is attached only to pet_type
        pet_type: str = Field(
            description="Species of pet", validators=[validator_spec(validator_func)]
        )
        name: str = Field(description="a unique pet name")

    guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
    if isinstance(expected_result, type) and issubclass(expected_result, Exception):
        # exception-style handler: the failure must propagate out of parse()
        with pytest.raises(expected_result):
            guard.parse(output)
    else:
        # num_reasks=0 so a FieldReAsk is recorded in history, not re-queried
        validated_output = guard.parse(output, num_reasks=0)
        if isinstance(expected_result, FieldReAsk):
            # the reask is stored on the guard's first history entry
            assert (
                guard.guard_state.all_histories[0].history[0].reasks[0]
                == expected_result
            )
        else:
            assert validated_output == expected_result

0 commit comments

Comments
 (0)