Skip to content

Commit 9c1dfb3

Browse files
authored
Merge pull request #493 from guardrails-ai/input-validation
Input validation
2 parents 8c4a3da + 837e1d2 commit 9c1dfb3

File tree

10 files changed

+1020
-80
lines changed

10 files changed

+1020
-80
lines changed

docs/examples/input_validation.ipynb

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Input Validation\n",
8+
"\n",
9+
"Guardrails supports validating inputs (prompts, instructions, msg_history) with string validators."
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"In XML, specify the validators on the `prompt` or `instructions` tag, as follows:"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {
23+
"is_executing": true
24+
},
25+
"outputs": [],
26+
"source": [
27+
"rail_spec = \"\"\"\n",
28+
"<rail version=\"0.1\">\n",
29+
"<prompt\n",
30+
" validators=\"two-words\"\n",
31+
" on-fail-two-words=\"exception\"\n",
32+
">\n",
33+
"This is not two words\n",
34+
"</prompt>\n",
35+
"<output type=\"string\">\n",
36+
"</output>\n",
37+
"</rail>\n",
38+
"\"\"\"\n",
39+
"\n",
40+
"from guardrails import Guard\n",
41+
"guard = Guard.from_rail_string(rail_spec)"
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"When `fix` is specified as the on-fail handler, the prompt will automatically be amended before calling the LLM.\n",
49+
"\n",
50+
"In any other case (for example, `exception`), a `ValidationException` will be returned in the outcome."
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {
57+
"is_executing": true
58+
},
59+
"outputs": [],
60+
"source": [
61+
"import openai\n",
62+
"\n",
63+
"outcome = guard(\n",
64+
" openai.ChatCompletion.create,\n",
65+
")\n",
66+
"outcome.error"
67+
]
68+
},
69+
{
70+
"cell_type": "markdown",
71+
"metadata": {},
72+
"source": [
73+
"When using pydantic to initialize a `Guard`, input validators can be specified by composition, as follows:"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"metadata": {},
80+
"outputs": [],
81+
"source": [
82+
"from guardrails.validators import TwoWords\n",
83+
"from pydantic import BaseModel\n",
84+
"\n",
85+
"\n",
86+
"class Pet(BaseModel):\n",
87+
" name: str\n",
88+
" age: int\n",
89+
"\n",
90+
"\n",
91+
"guard = Guard.from_pydantic(Pet)\n",
92+
"guard.with_prompt_validation([TwoWords(on_fail=\"exception\")])\n",
93+
"\n",
94+
"outcome = guard(\n",
95+
" openai.ChatCompletion.create,\n",
96+
" prompt=\"This is not two words\",\n",
97+
")\n",
98+
"outcome.error"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3 (ipykernel)",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"codemirror_mode": {
110+
"name": "ipython",
111+
"version": 3
112+
},
113+
"file_extension": ".py",
114+
"mimetype": "text/x-python",
115+
"name": "python",
116+
"nbconvert_exporter": "python",
117+
"pygments_lexer": "ipython3",
118+
"version": "3.11.0"
119+
}
120+
},
121+
"nbformat": 4,
122+
"nbformat_minor": 1
123+
}

guardrails/classes/history/call.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ def error(self) -> Optional[str]:
281281
return None
282282
return self.iterations.last.error # type: ignore
283283

284+
@property
285+
def exception(self) -> Optional[Exception]:
286+
"""The exception that interrupted the run."""
287+
if self.iterations.empty():
288+
return None
289+
return self.iterations.last.exception # type: ignore
290+
284291
@property
285292
def failed_validations(self) -> Stack[ValidatorLogs]:
286293
"""The validator logs for any validations that failed during the

guardrails/classes/history/iteration.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def error(self) -> Optional[str]:
111111
this iteration."""
112112
return self.outputs.error
113113

114+
@property
115+
def exception(self) -> Optional[Exception]:
116+
"""The exception that interrupted this iteration."""
117+
return self.outputs.exception
118+
114119
@property
115120
def failed_validations(self) -> List[ValidatorLogs]:
116121
"""The validator logs for any validations that failed during this

guardrails/classes/history/outputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class Outputs(ArbitraryModel):
4343
"that raised and interrupted the process.",
4444
default=None,
4545
)
46+
exception: Optional[Exception] = Field(
47+
description="The exception that interrupted the process.", default=None
48+
)
4649

4750
def _all_empty(self) -> bool:
4851
return (

guardrails/guard.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import contextvars
3+
import warnings
34
from typing import (
45
Any,
56
Awaitable,
@@ -27,7 +28,7 @@
2728
from guardrails.prompt import Instructions, Prompt
2829
from guardrails.rail import Rail
2930
from guardrails.run import AsyncRunner, Runner
30-
from guardrails.schema import Schema
31+
from guardrails.schema import Schema, StringSchema
3132
from guardrails.validators import Validator
3233

3334
add_destinations(logger.debug)
@@ -64,9 +65,19 @@ def __init__(
6465
self.base_model = base_model
6566

6667
@property
67-
def input_schema(self) -> Optional[Schema]:
68+
def prompt_schema(self) -> Optional[StringSchema]:
6869
"""Return the input schema."""
69-
return self.rail.input_schema
70+
return self.rail.prompt_schema
71+
72+
@property
73+
def instructions_schema(self) -> Optional[StringSchema]:
74+
"""Return the instructions schema."""
75+
return self.rail.instructions_schema
76+
77+
@property
78+
def msg_history_schema(self) -> Optional[StringSchema]:
79+
"""Return the msg_history schema."""
80+
return self.rail.msg_history_schema
7081

7182
@property
7283
def output_schema(self) -> Schema:
@@ -377,7 +388,9 @@ def _call_sync(
377388
prompt=prompt_obj,
378389
msg_history=msg_history_obj,
379390
api=get_llm_ask(llm_api, *args, **kwargs),
380-
input_schema=self.input_schema,
391+
prompt_schema=self.prompt_schema,
392+
instructions_schema=self.instructions_schema,
393+
msg_history_schema=self.msg_history_schema,
381394
output_schema=self.output_schema,
382395
num_reasks=num_reasks,
383396
metadata=metadata,
@@ -434,7 +447,9 @@ async def _call_async(
434447
prompt=prompt_obj,
435448
msg_history=msg_history_obj,
436449
api=get_async_llm_ask(llm_api, *args, **kwargs),
437-
input_schema=self.input_schema,
450+
prompt_schema=self.prompt_schema,
451+
instructions_schema=self.instructions_schema,
452+
msg_history_schema=self.msg_history_schema,
438453
output_schema=self.output_schema,
439454
num_reasks=num_reasks,
440455
metadata=metadata,
@@ -610,7 +625,9 @@ def _sync_parse(
610625
prompt=kwargs.pop("prompt", None),
611626
msg_history=kwargs.pop("msg_history", None),
612627
api=get_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
613-
input_schema=None,
628+
prompt_schema=self.prompt_schema,
629+
instructions_schema=self.instructions_schema,
630+
msg_history_schema=self.msg_history_schema,
614631
output_schema=self.output_schema,
615632
num_reasks=num_reasks,
616633
metadata=metadata,
@@ -650,7 +667,9 @@ async def _async_parse(
650667
prompt=kwargs.pop("prompt", None),
651668
msg_history=kwargs.pop("msg_history", None),
652669
api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
653-
input_schema=None,
670+
prompt_schema=self.prompt_schema,
671+
instructions_schema=self.instructions_schema,
672+
msg_history_schema=self.msg_history_schema,
654673
output_schema=self.output_schema,
655674
num_reasks=num_reasks,
656675
metadata=metadata,
@@ -663,3 +682,54 @@ async def _async_parse(
663682
)
664683

665684
return ValidationOutcome[OT].from_guard_history(call, error_message)
685+
686+
def with_prompt_validation(
687+
self,
688+
validators: Sequence[Validator],
689+
):
690+
"""Add prompt validation to the Guard.
691+
692+
Args:
693+
validators: The validators to add to the prompt.
694+
"""
695+
if self.rail.prompt_schema:
696+
warnings.warn("Overriding existing prompt validators.")
697+
schema = StringSchema.from_string(
698+
validators=validators,
699+
)
700+
self.rail.prompt_schema = schema
701+
return self
702+
703+
def with_instructions_validation(
704+
self,
705+
validators: Sequence[Validator],
706+
):
707+
"""Add instructions validation to the Guard.
708+
709+
Args:
710+
validators: The validators to add to the instructions.
711+
"""
712+
if self.rail.instructions_schema:
713+
warnings.warn("Overriding existing instructions validators.")
714+
schema = StringSchema.from_string(
715+
validators=validators,
716+
)
717+
self.rail.instructions_schema = schema
718+
return self
719+
720+
def with_msg_history_validation(
721+
self,
722+
validators: Sequence[Validator],
723+
):
724+
"""Add msg_history validation to the Guard.
725+
726+
Args:
727+
validators: The validators to add to the msg_history.
728+
"""
729+
if self.rail.msg_history_schema:
730+
warnings.warn("Overriding existing msg_history validators.")
731+
schema = StringSchema.from_string(
732+
validators=validators,
733+
)
734+
self.rail.msg_history_schema = schema
735+
return self

0 commit comments

Comments
 (0)