|
1 | 1 | # noqa:W291
|
2 | 2 | import os
|
3 |
| -from typing import Any, Dict |
| 3 | +from typing import Any, Dict, List |
4 | 4 |
|
5 | 5 | import openai
|
6 | 6 | import pytest
|
|
16 | 16 | PassResult,
|
17 | 17 | Refrain,
|
18 | 18 | ValidationResult,
|
| 19 | + ValidatorError, |
19 | 20 | check_refrain_in_dict,
|
20 | 21 | filter_in_dict,
|
21 | 22 | register_validator,
|
@@ -503,3 +504,103 @@ def test_detect_secrets():
|
503 | 504 | # Check if mod_value is same as code_snippet,
|
504 | 505 | # as there are no secrets in code_snippet
|
505 | 506 | assert mod_value == NO_SECRETS_CODE_SNIPPET
|
| 507 | + |
| 508 | + |
def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """On-fail handler that 'fixes' a failing value by repeating it once.

    E.g. "dog" -> "dog dog", which satisfies a two-word validator.
    """
    return " ".join([value, value])
| 511 | + |
| 512 | + |
def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """On-fail handler that escalates the failure into a re-ask."""
    reask = FieldReAsk(incorrect_value=value, fail_results=fail_results)
    return reask
| 515 | + |
| 516 | + |
def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """On-fail handler that aborts validation by raising a ValidatorError."""
    message = "Something went wrong!"
    raise ValidatorError(message)
| 519 | + |
| 520 | + |
def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """On-fail handler that drops the offending field via a Filter sentinel."""
    return Filter()
| 523 | + |
| 524 | + |
def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]):
    """On-fail handler that refuses the whole output via a Refrain sentinel."""
    return Refrain()
| 527 | + |
| 528 | + |
# Each (handler, expected) pair below pins the observable outcome of one
# on_fail strategy when the "pet_type" value ("dog") fails the TwoWords
# validator:
#   fix      -> value rewritten to "dog dog" in the validated output
#   reask    -> a FieldReAsk recorded in the guard history
#   raise    -> ValidatorError propagates out of guard.parse
#   filter   -> the failing field is removed from the output dict
#   refrain  -> the entire output is replaced with an empty dict
@pytest.mark.parametrize(
    "validator_func, expected_result",
    [
        (
            custom_fix_on_fail_handler,
            {"pet_type": "dog dog", "name": "Fido"},
        ),
        (
            custom_reask_on_fail_handler,
            FieldReAsk(
                incorrect_value="dog",
                path=["pet_type"],
                fail_results=[
                    FailResult(
                        error_message="must be exactly two words",
                        fix_value="dog",
                    )
                ],
            ),
        ),
        (
            custom_exception_on_fail_handler,
            ValidatorError,
        ),
        (
            custom_filter_on_fail_handler,
            {"name": "Fido"},
        ),
        (
            custom_refrain_on_fail_handler,
            {},
        ),
    ],
)
# Run every case twice: once with the validator passed as an instance
# (TwoWords(on_fail=...)) and once in the ("name", callable) tuple form,
# so both registration styles are covered by the same expectations.
@pytest.mark.parametrize(
    "validator_spec",
    [
        lambda val_func: TwoWords(on_fail=val_func),
        lambda val_func: ("two-words", val_func),
    ],
)
def test_custom_on_fail_handler(
    validator_spec,
    validator_func,
    expected_result,
):
    """Verify that user-supplied on_fail handlers control validation outcomes.

    Builds a Guard from a pydantic model whose ``pet_type`` field carries a
    TwoWords validator wired to ``validator_func``, parses a fixed LLM-style
    JSON output whose ``pet_type`` ("dog") fails that validator, and asserts
    the result matches ``expected_result`` for the given handler strategy.
    """
    prompt = """
    What kind of pet should I get and what should I name it?

    ${gr.complete_json_suffix_v2}
    """

    # Canned "LLM output" — no model call is made; guard.parse validates
    # this string directly.
    output = """
    {
        "pet_type": "dog",
        "name": "Fido"
    }
    """

    class Pet(BaseModel):
        # Only pet_type carries the failing validator; name always passes.
        pet_type: str = Field(
            description="Species of pet", validators=[validator_spec(validator_func)]
        )
        name: str = Field(description="a unique pet name")

    guard = Guard.from_pydantic(output_class=Pet, prompt=prompt)
    if isinstance(expected_result, type) and issubclass(expected_result, Exception):
        # Exception strategy: the handler raises and parse must propagate it.
        with pytest.raises(expected_result):
            guard.parse(output)
    else:
        # num_reasks=0 so a re-ask is only *recorded*, never re-executed.
        validated_output = guard.parse(output, num_reasks=0)
        if isinstance(expected_result, FieldReAsk):
            # Re-ask strategy: the outcome lives in the guard history,
            # not in the returned value.
            # NOTE(review): assumes the first history entry of the first
            # run holds the reask — confirm against guard_state internals.
            assert (
                guard.guard_state.all_histories[0].history[0].reasks[0]
                == expected_result
            )
        else:
            assert validated_output == expected_result
0 commit comments