Skip to content

Input validation #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions docs/examples/input_validation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Input Validation\n",
"\n",
"Guardrails supports validating inputs (prompts, instructions, msg_history) with string validators."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In XML, specify the validators on the `prompt` or `instructions` tag, as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"is_executing": true
},
"outputs": [],
"source": [
"rail_spec = \"\"\"\n",
"<rail version=\"0.1\">\n",
"<prompt\n",
" validators=\"two-words\"\n",
" on-fail-two-words=\"exception\"\n",
">\n",
"This is not two words\n",
"</prompt>\n",
"<output type=\"string\">\n",
"</output>\n",
"</rail>\n",
"\"\"\"\n",
"\n",
"from guardrails import Guard\n",
"guard = Guard.from_rail_string(rail_spec)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `fix` is specified as the on-fail handler, the prompt will automatically be amended before calling the LLM.\n",
"\n",
"In any other case (for example, `exception`), a `ValidationException` will be returned in the outcome."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"is_executing": true
},
"outputs": [],
"source": [
"import openai\n",
"\n",
"outcome = guard(\n",
" openai.ChatCompletion.create,\n",
")\n",
"outcome.error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When using pydantic to initialize a `Guard`, input validators can be specified by composition, as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from guardrails.validators import TwoWords\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Pet(BaseModel):\n",
" name: str\n",
" age: int\n",
"\n",
"\n",
"guard = Guard.from_pydantic(Pet)\n",
"guard.with_prompt_validation([TwoWords(on_fail=\"exception\")])\n",
"\n",
"outcome = guard(\n",
" openai.ChatCompletion.create,\n",
" prompt=\"This is not two words\",\n",
")\n",
"outcome.error"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
7 changes: 7 additions & 0 deletions guardrails/classes/history/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ def error(self) -> Optional[str]:
return None
return self.iterations.last.error # type: ignore

@property
def exception(self) -> Optional[Exception]:
    """The exception that interrupted the run, if any."""
    # With no iterations recorded, nothing could have raised yet.
    return (
        None
        if self.iterations.empty()
        else self.iterations.last.exception  # type: ignore
    )

@property
def failed_validations(self) -> Stack[ValidatorLogs]:
"""The validator logs for any validations that failed during the
Expand Down
5 changes: 5 additions & 0 deletions guardrails/classes/history/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def error(self) -> Optional[str]:
this iteration."""
return self.outputs.error

@property
def exception(self) -> Optional[Exception]:
    """The exception that interrupted this iteration, if any."""
    # Straight delegation to the recorded outputs.
    outputs = self.outputs
    return outputs.exception

@property
def failed_validations(self) -> List[ValidatorLogs]:
"""The validator logs for any validations that failed during this
Expand Down
3 changes: 3 additions & 0 deletions guardrails/classes/history/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class Outputs(ArbitraryModel):
"that raised and interrupted the process.",
default=None,
)
exception: Optional[Exception] = Field(
description="The exception that interrupted the process.", default=None
)

def _all_empty(self) -> bool:
return (
Expand Down
84 changes: 77 additions & 7 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import contextvars
import warnings
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -27,7 +28,7 @@
from guardrails.prompt import Instructions, Prompt
from guardrails.rail import Rail
from guardrails.run import AsyncRunner, Runner
from guardrails.schema import Schema
from guardrails.schema import Schema, StringSchema
from guardrails.validators import Validator

add_destinations(logger.debug)
Expand Down Expand Up @@ -64,9 +65,19 @@ def __init__(
self.base_model = base_model

@property
def prompt_schema(self) -> Optional["StringSchema"]:
    """Return the schema used to validate the prompt, if one is set."""
    return self.rail.prompt_schema

@property
def instructions_schema(self) -> Optional["StringSchema"]:
    """Return the schema used to validate the instructions, if one is set."""
    return self.rail.instructions_schema

@property
def msg_history_schema(self) -> Optional["StringSchema"]:
    """Return the schema used to validate the message history, if one is set."""
    return self.rail.msg_history_schema

@property
def output_schema(self) -> Schema:
Expand Down Expand Up @@ -377,7 +388,9 @@ def _call_sync(
prompt=prompt_obj,
msg_history=msg_history_obj,
api=get_llm_ask(llm_api, *args, **kwargs),
input_schema=self.input_schema,
prompt_schema=self.prompt_schema,
instructions_schema=self.instructions_schema,
msg_history_schema=self.msg_history_schema,
output_schema=self.output_schema,
num_reasks=num_reasks,
metadata=metadata,
Expand Down Expand Up @@ -434,7 +447,9 @@ async def _call_async(
prompt=prompt_obj,
msg_history=msg_history_obj,
api=get_async_llm_ask(llm_api, *args, **kwargs),
input_schema=self.input_schema,
prompt_schema=self.prompt_schema,
instructions_schema=self.instructions_schema,
msg_history_schema=self.msg_history_schema,
output_schema=self.output_schema,
num_reasks=num_reasks,
metadata=metadata,
Expand Down Expand Up @@ -610,7 +625,9 @@ def _sync_parse(
prompt=kwargs.pop("prompt", None),
msg_history=kwargs.pop("msg_history", None),
api=get_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
input_schema=None,
prompt_schema=self.prompt_schema,
instructions_schema=self.instructions_schema,
msg_history_schema=self.msg_history_schema,
output_schema=self.output_schema,
num_reasks=num_reasks,
metadata=metadata,
Expand Down Expand Up @@ -650,7 +667,9 @@ async def _async_parse(
prompt=kwargs.pop("prompt", None),
msg_history=kwargs.pop("msg_history", None),
api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None,
input_schema=None,
prompt_schema=self.prompt_schema,
instructions_schema=self.instructions_schema,
msg_history_schema=self.msg_history_schema,
output_schema=self.output_schema,
num_reasks=num_reasks,
metadata=metadata,
Expand All @@ -663,3 +682,54 @@ async def _async_parse(
)

return ValidationOutcome[OT].from_guard_history(call, error_message)

def with_prompt_validation(
    self,
    validators: Sequence[Validator],
):
    """Attach string validators to this guard's prompt.

    Args:
        validators: The validators to run against the prompt.

    Returns:
        The guard itself, enabling call chaining.
    """
    existing = self.rail.prompt_schema
    if existing:
        # Replacing an already-configured schema is allowed, but warn.
        warnings.warn("Overriding existing prompt validators.")
    self.rail.prompt_schema = StringSchema.from_string(validators=validators)
    return self

def with_instructions_validation(
    self,
    validators: Sequence[Validator],
):
    """Attach string validators to this guard's instructions.

    Args:
        validators: The validators to run against the instructions.

    Returns:
        The guard itself, enabling call chaining.
    """
    existing = self.rail.instructions_schema
    if existing:
        # Replacing an already-configured schema is allowed, but warn.
        warnings.warn("Overriding existing instructions validators.")
    self.rail.instructions_schema = StringSchema.from_string(validators=validators)
    return self

def with_msg_history_validation(
    self,
    validators: Sequence[Validator],
):
    """Attach string validators to this guard's message history.

    Args:
        validators: The validators to run against the msg_history.

    Returns:
        The guard itself, enabling call chaining.
    """
    existing = self.rail.msg_history_schema
    if existing:
        # Replacing an already-configured schema is allowed, but warn.
        warnings.warn("Overriding existing msg_history validators.")
    self.rail.msg_history_schema = StringSchema.from_string(validators=validators)
    return self
Loading