From 73a1d093cdc86a89ac4706c9c776141664ae5f4f Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Wed, 29 Nov 2023 17:22:39 +0000 Subject: [PATCH 01/11] Internal plumbing for input validation --- guardrails/guard.py | 15 +++++++--- guardrails/rail.py | 19 +++++++------ guardrails/run.py | 68 ++++++++++++++++++++++++++++++++++++--------- 3 files changed, 77 insertions(+), 25 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 66edd4a71..89206cff3 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -1,6 +1,7 @@ import asyncio import contextvars import logging +import warnings from typing import ( Any, Awaitable, @@ -22,7 +23,7 @@ from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail from guardrails.run import AsyncRunner, Runner -from guardrails.schema import Schema +from guardrails.schema import Schema, StringSchema from guardrails.utils.logs_utils import GuardState from guardrails.utils.reask_utils import sub_reasks_with_fixed_values from guardrails.validators import Validator @@ -62,9 +63,14 @@ def __init__( self.base_model = base_model @property - def input_schema(self) -> Optional[Schema]: + def prompt_schema(self) -> Optional[Schema]: """Return the input schema.""" - return self.rail.input_schema + return self.rail.prompt_schema + + @property + def instructions_schema(self) -> Optional[Schema]: + """Return the input schema.""" + return self.rail.instructions_schema @property def output_schema(self) -> Schema: @@ -351,7 +357,8 @@ def _call_sync( prompt=prompt_obj, msg_history=msg_history_obj, api=get_llm_ask(llm_api, *args, **kwargs), - input_schema=self.input_schema, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, output_schema=self.output_schema, num_reasks=num_reasks, metadata=metadata, diff --git a/guardrails/rail.py b/guardrails/rail.py index ca97c6fa4..4653eb005 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -29,7 +29,8 @@ class Rail: 4. ``, which contains the instructions to be passed to the LLM """ - input_schema: Optional[Schema] + prompt_schema: Optional[Schema] + instructions_schema: Optional[Schema] output_schema: Schema instructions: Optional[Instructions] prompt: Optional[Prompt] @@ -44,8 +45,6 @@ def from_pydantic( reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, ): - input_schema = None - output_schema = cls.load_json_schema_from_pydantic( output_class, reask_prompt_template=reask_prompt, @@ -53,7 +52,8 @@ def from_pydantic( ) return cls( - input_schema=input_schema, + prompt_schema=None, + instructions_schema=None, output_schema=output_schema, instructions=cls.load_instructions(instructions, output_schema), prompt=cls.load_prompt(prompt, output_schema), @@ -78,12 +78,15 @@ def from_xml(cls, xml: ET._Element): ) # Load schema + # TODO change this to `prompt_validators` and `instructions_validators` raw_input_schema = xml.find("input") if raw_input_schema is None: # No input schema, so do no input checking. 
input_schema = None else: input_schema = cls.load_input_schema_from_xml(raw_input_schema) + prompt_schema = None + instructions_schema = None # Load schema raw_output_schema = xml.find("output") @@ -123,7 +126,8 @@ def from_xml(cls, xml: ET._Element): version = cast_xml_to_string(version) return cls( - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, output_schema=output_schema, instructions=instructions, prompt=prompt, @@ -140,8 +144,6 @@ def from_string_validators( reask_prompt: Optional[str] = None, reask_instructions: Optional[str] = None, ): - input_schema = None - output_schema = cls.load_string_schema_from_string( validators, description=description, @@ -150,7 +152,8 @@ def from_string_validators( ) return cls( - input_schema=input_schema, + prompt_schema=None, + instructions_schema=None, output_schema=output_schema, instructions=cls.load_instructions(instructions, output_schema), prompt=cls.load_prompt(prompt, output_schema), diff --git a/guardrails/run.py b/guardrails/run.py index f370a53af..f515e363c 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -8,7 +8,7 @@ from guardrails.datatypes import verify_metadata_requirements from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase from guardrails.prompt import Instructions, Prompt -from guardrails.schema import Schema +from guardrails.schema import Schema, StringSchema from guardrails.utils.llm_response import LLMResponse from guardrails.utils.logs_utils import GuardHistory, GuardLogs, GuardState from guardrails.utils.reask_utils import ( @@ -18,6 +18,7 @@ reasks_to_dict, sub_reasks_with_fixed_values, ) +from guardrails.validator_base import ValidatorError logger = logging.getLogger(__name__) actions_logger = logging.getLogger(f"{__name__}.actions") @@ -52,7 +53,8 @@ def __init__( instructions: Optional[Union[str, Instructions]] = None, msg_history: Optional[List[Dict]] = None, api: Optional[PromptCallableBase] = None, - input_schema: Optional[Schema] = None, + prompt_schema: Optional[StringSchema] = None, + instructions_schema: Optional[StringSchema] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, guard_history: Optional[GuardHistory] = None, @@ -88,7 +90,8 @@ def __init__( self.msg_history = None self.api = api - self.input_schema = input_schema + self.prompt_schema = prompt_schema + self.instructions_schema = instructions_schema self.output_schema = output_schema self.guard_state = guard_state self.num_reasks = num_reasks @@ -138,16 +141,18 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory: instructions=self.instructions, prompt=self.prompt, api=self.api, - input_schema=self.input_schema, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, output_schema=self.output_schema, num_reasks=self.num_reasks, metadata=self.metadata, ): - instructions, prompt, msg_history, input_schema, output_schema = ( + instructions, prompt, msg_history, prompt_schema, instructions_schema, output_schema = ( self.instructions, self.prompt, self.msg_history, - self.input_schema, + self.prompt_schema, + self.instructions_schema, self.output_schema, ) for index in range(self.num_reasks + 1): @@ -159,7 +164,8 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory: prompt=prompt, msg_history=msg_history, prompt_params=prompt_params, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, output_schema=output_schema, 
output=self.output if index == 0 else None, ) @@ -186,7 +192,8 @@ def step( prompt: Optional[Prompt], msg_history: Optional[List[Dict]], prompt_params: Dict, - input_schema: Optional[Schema], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], output_schema: Schema, output: Optional[str] = None, ): @@ -199,7 +206,8 @@ def step( instructions=instructions, prompt=prompt, prompt_params=prompt_params, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, output_schema=output_schema, ): # Prepare: run pre-processing, and input validation. @@ -209,13 +217,15 @@ def step( msg_history = None else: instructions, prompt, msg_history = self.prepare( + guard_logs, # TODO pass something else here index, instructions, prompt, msg_history, prompt_params, api, - input_schema, + prompt_schema, + instructions_schema, output_schema, ) @@ -265,13 +275,15 @@ def step( def prepare( self, + guard_logs: GuardLogs, index: int, instructions: Optional[Instructions], prompt: Optional[Prompt], msg_history: Optional[List[Dict]], prompt_params: Dict, api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], - input_schema: Optional[Schema], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], output_schema: Schema, ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: """Prepare by running pre-processing and input validation. @@ -293,6 +305,8 @@ def prepare( msg["content"] = msg["content"].format(**prompt_params) prompt, instructions = None, None + + # TODO figure out what to do with msg_history in terms of input validation elif prompt is not None: if isinstance(prompt, str): prompt = Prompt(prompt) @@ -307,6 +321,32 @@ def prepare( instructions, prompt = output_schema.preprocess_prompt( api, instructions, prompt ) + + # validate prompt + if prompt_schema is not None: + validated_prompt = prompt_schema.validate( + guard_logs, prompt.source, self.metadata + ) + if validated_prompt is None: + raise ValidatorError("Prompt validation failed") + if isinstance(validated_prompt, ReAsk): + raise ValidatorError( + f"Prompt validation failed: {validated_prompt}" + ) + prompt = Prompt(validated_prompt) + + # validate instructions + if instructions_schema is not None: + validated_instructions = instructions_schema.validate( + guard_logs, instructions.source, self.metadata + ) + if validated_instructions is None: + raise ValidatorError("Instructions validation failed") + if isinstance(validated_instructions, ReAsk): + raise ValidatorError( + f"Instructions validation failed: {validated_instructions}" + ) + instructions = Instructions(validated_instructions) else: raise ValueError("Prompt or message history must be provided.") @@ -591,14 +631,16 @@ async def async_step( prompt = None msg_history = None else: - instructions, prompt, msg_history = self.prepare( + instructions, prompt, msg_history = await self.async_prepare( + guard_logs, # TODO pass something else here index, instructions, prompt, msg_history, prompt_params, api, - input_schema, + prompt_schema, + instructions_schema, output_schema, ) From ad2bc9fed4fecff080801a485c8fc6c0c5320112 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Wed, 29 Nov 2023 17:22:48 +0000 Subject: [PATCH 02/11] Input validation via composition --- guardrails/guard.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/guardrails/guard.py b/guardrails/guard.py index 89206cff3..cd007063f 100644 --- 
a/guardrails/guard.py +++ b/guardrails/guard.py @@ -619,3 +619,41 @@ async def _async_parse( ) guard_history = await runner.async_run(prompt_params=prompt_params) return sub_reasks_with_fixed_values(guard_history.validated_output) + + def with_prompt_validation( + self, + validators: Sequence[Validator], + ): + """Add prompt validation to the Guard. + + Args: + validators: The validators to add to the prompt. + """ + if self.rail.prompt_schema: + warnings.warn( + "Overriding existing prompt validators." + ) + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.prompt_schema = schema + return self + + def with_instructions_validation( + self, + validators: Sequence[Validator], + ): + """Add instructions validation to the Guard. + + Args: + validators: The validators to add to the instructions. + """ + if self.rail.instructions_schema: + warnings.warn( + "Overriding existing instructions validators." + ) + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.instructions_schema = schema + return self From 8fb9bcf8aea5a2c31a6f0ba0a47c00f8d6ef39ec Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Thu, 30 Nov 2023 17:39:01 +0000 Subject: [PATCH 03/11] msg_history validation --- guardrails/guard.py | 37 ++++++- guardrails/rail.py | 43 ++++---- guardrails/run.py | 156 ++++++++++++++++++++++++++-- tests/integration_tests/test_run.py | 6 +- tests/unit_tests/test_validators.py | 14 +++ 5 files changed, 222 insertions(+), 34 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index cd007063f..fcc868b42 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -72,6 +72,11 @@ def instructions_schema(self) -> Optional[Schema]: """Return the input schema.""" return self.rail.instructions_schema + @property + def msg_history_schema(self) -> Optional[Schema]: + """Return the input schema.""" + return self.rail.msg_history_schema + @property def output_schema(self) -> Schema: """Return the output schema.""" @@ -359,6 +364,7 @@ def _call_sync( api=get_llm_ask(llm_api, *args, **kwargs), prompt_schema=self.prompt_schema, instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=num_reasks, metadata=metadata, @@ -415,7 +421,9 @@ async def _call_async( prompt=prompt_obj, msg_history=msg_history_obj, api=get_async_llm_ask(llm_api, *args, **kwargs), - input_schema=self.input_schema, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=num_reasks, metadata=metadata, @@ -569,7 +577,9 @@ def _sync_parse( prompt=kwargs.pop("prompt", None), msg_history=kwargs.pop("msg_history", None), api=get_llm_ask(llm_api, *args, **kwargs) if llm_api else None, - input_schema=None, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=num_reasks, metadata=metadata, @@ -608,7 +618,9 @@ async def _async_parse( prompt=kwargs.pop("prompt", None), msg_history=kwargs.pop("msg_history", None), api=get_async_llm_ask(llm_api, *args, **kwargs) if llm_api else None, - input_schema=None, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=num_reasks, metadata=metadata, @@ -657,3 +669,22 @@ def with_instructions_validation( ) 
self.rail.instructions_schema = schema return self + + def with_msg_history_validation( + self, + validators: Sequence[Validator], + ): + """Add msg_history validation to the Guard. + + Args: + validators: The validators to add to the msg_history. + """ + if self.rail.msg_history_schema: + warnings.warn( + "Overriding existing msg_history validators." + ) + schema = StringSchema.from_string( + validators=validators, + ) + self.rail.msg_history_schema = schema + return self diff --git a/guardrails/rail.py b/guardrails/rail.py index 4653eb005..07c71c5c4 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -31,6 +31,7 @@ class Rail: prompt_schema: Optional[Schema] instructions_schema: Optional[Schema] + msg_history_schema: Optional[Schema] output_schema: Schema instructions: Optional[Instructions] prompt: Optional[Prompt] @@ -54,6 +55,7 @@ def from_pydantic( return cls( prompt_schema=None, instructions_schema=None, + msg_history_schema=None, output_schema=output_schema, instructions=cls.load_instructions(instructions, output_schema), prompt=cls.load_prompt(prompt, output_schema), @@ -77,17 +79,6 @@ def from_xml(cls, xml: ET._Element): "Change the opening element to: ." ) - # Load schema - # TODO change this to `prompt_validators` and `instructions_validators` - raw_input_schema = xml.find("input") - if raw_input_schema is None: - # No input schema, so do no input checking. - input_schema = None - else: - input_schema = cls.load_input_schema_from_xml(raw_input_schema) - prompt_schema = None - instructions_schema = None - # Load schema raw_output_schema = xml.find("output") if raw_output_schema is None: @@ -110,16 +101,23 @@ def from_xml(cls, xml: ET._Element): # Parse instructions for the LLM. These are optional but if given, # LLMs can use them to improve their output. Commonly these are # prepended to the prompt. - instructions = xml.find("instructions") - if instructions is not None: - instructions = cls.load_instructions(instructions.text, output_schema) + instructions_tag = xml.find("instructions") + if instructions_tag is None: + instructions = None + instructions_schema = None + else: + instructions = cls.load_instructions(instructions_tag.text, output_schema) + instructions_schema = cls.load_input_schema_from_xml(instructions_tag) # Load - prompt = xml.find("prompt") - if prompt is None: + prompt_tag = xml.find("prompt") + if prompt_tag is None: warnings.warn("Prompt must be provided during __call__.") + prompt = None + prompt_schema = None else: - prompt = cls.load_prompt(prompt.text, output_schema) + prompt = cls.load_prompt(prompt_tag.text, output_schema) + prompt_schema = cls.load_input_schema_from_xml(prompt_tag) # Get version version = xml.attrib["version"] @@ -128,6 +126,7 @@ def from_xml(cls, xml: ET._Element): return cls( prompt_schema=prompt_schema, instructions_schema=instructions_schema, + msg_history_schema=None, output_schema=output_schema, instructions=instructions, prompt=prompt, @@ -154,16 +153,20 @@ def from_string_validators( return cls( prompt_schema=None, instructions_schema=None, + msg_history_schema=None, output_schema=output_schema, instructions=cls.load_instructions(instructions, output_schema), prompt=cls.load_prompt(prompt, output_schema), ) @staticmethod - def load_input_schema_from_xml(root: ET._Element) -> Schema: + def load_input_schema_from_xml( + root: Optional[ET._Element] + ) -> Optional[StringSchema]: """Given the RAIL element, create a Schema object.""" - # Recast the schema as an InputSchema. 
- return Schema.from_xml(root) + if root is None: + return None + return StringSchema.from_xml(root) @staticmethod def load_output_schema_from_xml( diff --git a/guardrails/run.py b/guardrails/run.py index f515e363c..619ae44f0 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -55,6 +55,7 @@ def __init__( api: Optional[PromptCallableBase] = None, prompt_schema: Optional[StringSchema] = None, instructions_schema: Optional[StringSchema] = None, + msg_history_schema: Optional[StringSchema] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, guard_history: Optional[GuardHistory] = None, @@ -92,6 +93,7 @@ def __init__( self.api = api self.prompt_schema = prompt_schema self.instructions_schema = instructions_schema + self.msg_history_schema = msg_history_schema self.output_schema = output_schema self.guard_state = guard_state self.num_reasks = num_reasks @@ -143,6 +145,7 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory: api=self.api, prompt_schema=self.prompt_schema, instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=self.num_reasks, metadata=self.metadata, @@ -166,6 +169,7 @@ def __call__(self, prompt_params: Optional[Dict] = None) -> GuardHistory: prompt_params=prompt_params, prompt_schema=prompt_schema, instructions_schema=instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=output_schema, output=self.output if index == 0 else None, ) @@ -208,6 +212,7 @@ def step( prompt_params=prompt_params, prompt_schema=prompt_schema, instructions_schema=instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=output_schema, ): # Prepare: run pre-processing, and input validation. @@ -284,6 +289,7 @@ def prepare( api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], prompt_schema: Optional[StringSchema], instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], output_schema: Schema, ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: """Prepare by running pre-processing and input validation. 
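As a caller-side sketch of what the msg_history schema wired in by the next hunk enables (illustration only, not part of the patch; the Pet model, TwoWords validator and static OpenAI helper mirror the tests added later in this series, and the import paths are assumed from that test suite):

    from pydantic import BaseModel, Field
    from guardrails import Guard
    from guardrails.utils.openai_utils import get_static_openai_create_func
    from guardrails.validator_base import ValidatorError
    from guardrails.validators import TwoWords

    class Pet(BaseModel):
        name: str = Field(description="a unique pet name")

    guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation(
        validators=[TwoWords(on_fail="exception")]
    )
    try:
        guard(
            get_static_openai_create_func(),
            msg_history=[{"role": "user", "content": "What kind of pet should I get?"}],
        )
    except ValidatorError:
        # the concatenated message contents fail the two-word check,
        # so prepare() raises before the LLM call is made
        pass
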
@@ -306,7 +312,18 @@ def prepare( prompt, instructions = None, None - # TODO figure out what to do with msg_history in terms of input validation + # validate msg_history + if msg_history_schema is not None: + validated_msg_history = msg_history_schema.validate( + guard_logs, msg_history_string(msg_history), self.metadata + ) + if validated_msg_history is None: + raise ValidatorError("Message history validation failed") + if isinstance(validated_msg_history, ReAsk): + raise ValidatorError( + f"Message history validation failed: {validated_msg_history}" + ) + msg_history = validated_msg_history elif prompt is not None: if isinstance(prompt, str): prompt = Prompt(prompt) @@ -509,7 +526,9 @@ def __init__( instructions: Optional[Union[str, Instructions]] = None, msg_history: Optional[List[Dict]] = None, api: Optional[AsyncPromptCallableBase] = None, - input_schema: Optional[Schema] = None, + prompt_schema: Optional[StringSchema] = None, + instructions_schema: Optional[StringSchema] = None, + msg_history_schema: Optional[StringSchema] = None, metadata: Optional[Dict[str, Any]] = None, output: Optional[str] = None, guard_history: Optional[GuardHistory] = None, @@ -524,7 +543,9 @@ def __init__( instructions=instructions, msg_history=msg_history, api=api, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=msg_history_schema, metadata=metadata, output=output, guard_history=guard_history, @@ -562,16 +583,20 @@ async def async_run(self, prompt_params: Optional[Dict] = None) -> GuardHistory: instructions=self.instructions, prompt=self.prompt, api=self.api, - input_schema=self.input_schema, + prompt_schema=self.prompt_schema, + instructions_schema=self.instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=self.output_schema, num_reasks=self.num_reasks, metadata=self.metadata, ): - instructions, prompt, msg_history, input_schema, output_schema = ( + instructions, prompt, msg_history, prompt_schema, instructions_schema, msg_history_schema, output_schema = ( self.instructions, self.prompt, self.msg_history, - self.input_schema, + self.prompt_schema, + self.instructions_schema, + self.msg_history_schema, self.output_schema, ) for index in range(self.num_reasks + 1): @@ -583,7 +608,9 @@ async def async_run(self, prompt_params: Optional[Dict] = None) -> GuardHistory: prompt=prompt, msg_history=msg_history, prompt_params=prompt_params, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=msg_history_schema, output_schema=output_schema, output=self.output if index == 0 else None, ) @@ -609,7 +636,9 @@ async def async_step( prompt: Optional[Prompt], msg_history: Optional[List[Dict]], prompt_params: Dict, - input_schema: Optional[Schema], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], output_schema: Schema, output: Optional[str] = None, ): @@ -622,7 +651,9 @@ async def async_step( instructions=instructions, prompt=prompt, prompt_params=prompt_params, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=self.msg_history_schema, output_schema=output_schema, ): # Prepare: run pre-processing, and input validation. 
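The prompt and instructions schemas follow the same pattern on both the sync and async paths: the rendered source is validated, and an amended value from a fix-style on_fail policy replaces the original before the LLM call. A rough sketch of the caller-facing behaviour, reusing the model, validator and helper from the sketch above:

    guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
        validators=[TwoWords(on_fail="fix")]
    )
    guard(
        get_static_openai_create_func(),
        prompt="What kind of pet should I get?",
    )
    # TwoWords(on_fail="fix") amends the prompt to "What kind",
    # and prepare() passes that amended Prompt on to the LLM call
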
@@ -641,6 +672,7 @@ async def async_step( api, prompt_schema, instructions_schema, + msg_history_schema, output_schema, ) @@ -766,9 +798,115 @@ async def async_validate( return validated_output + async def async_prepare( + self, + guard_logs: GuardLogs, + index: int, + instructions: Optional[Instructions], + prompt: Optional[Prompt], + msg_history: Optional[List[Dict]], + prompt_params: Dict, + api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], + output_schema: Schema, + ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: + """Prepare by running pre-processing and input validation. + + Returns: + The instructions, prompt, and message history. + """ + with start_action(action_type="prepare", index=index) as action: + if api is None: + raise ValueError("API must be provided.") + + if prompt_params is None: + prompt_params = {} + + if msg_history: + msg_history = copy.deepcopy(msg_history) + # Format any variables in the message history with the prompt params. + for msg in msg_history: + msg["content"] = msg["content"].format(**prompt_params) + + prompt, instructions = None, None + + # validate msg_history + if msg_history_schema is not None: + validated_msg_history = await msg_history_schema.async_validate( + guard_logs, msg_history_string(msg_history), self.metadata + ) + if validated_msg_history is None: + raise ValidatorError("Message history validation failed") + if isinstance(validated_msg_history, ReAsk): + raise ValidatorError( + f"Message history validation failed: {validated_msg_history}" + ) + msg_history = validated_msg_history + elif prompt is not None: + if isinstance(prompt, str): + prompt = Prompt(prompt) + + prompt = prompt.format(**prompt_params) + + # TODO(shreya): should there be any difference + # to parsing params for prompt? 
+ if instructions is not None and isinstance(instructions, Instructions): + instructions = instructions.format(**prompt_params) + + instructions, prompt = output_schema.preprocess_prompt( + api, instructions, prompt + ) + + # validate prompt + if prompt_schema is not None: + validated_prompt = await prompt_schema.async_validate( + guard_logs, prompt.source, self.metadata + ) + if validated_prompt is None: + raise ValidatorError("Prompt validation failed") + if isinstance(validated_prompt, ReAsk): + raise ValidatorError( + f"Prompt validation failed: {validated_prompt}" + ) + prompt = Prompt(validated_prompt) + + # validate instructions + if instructions_schema is not None: + validated_instructions = await instructions_schema.async_validate( + guard_logs, instructions.source, self.metadata + ) + if validated_instructions is None: + raise ValidatorError("Instructions validation failed") + if isinstance(validated_instructions, ReAsk): + raise ValidatorError( + f"Instructions validation failed: {validated_instructions}" + ) + instructions = Instructions(validated_instructions) + else: + raise ValueError("Prompt or message history must be provided.") + + action.log( + message_type="info", + instructions=instructions, + prompt=prompt, + prompt_params=prompt_params, + validated_prompt_params=prompt_params, + ) + + return instructions, prompt, msg_history + def msg_history_source(msg_history) -> List[Dict[str, str]]: msg_history_copy = copy.deepcopy(msg_history) for msg in msg_history_copy: msg["content"] = msg["content"].source return msg_history_copy + + +def msg_history_string(msg_history) -> str: + msg_history_copy = "" + for msg in msg_history: + msg_history_copy += msg["content"].source + return msg_history_copy diff --git a/tests/integration_tests/test_run.py b/tests/integration_tests/test_run.py index bdacf2680..792ab811f 100644 --- a/tests/integration_tests/test_run.py +++ b/tests/integration_tests/test_run.py @@ -40,7 +40,8 @@ def runner_instance(is_sync: bool): prompt=PROMPT, msg_history=None, api=OpenAICallable, - input_schema=None, + prompt_schema=None, + instructions_schema=None, output_schema=OUTPUT_SCHEMA, guard_state={}, num_reasks=0, @@ -51,7 +52,8 @@ def runner_instance(is_sync: bool): prompt=PROMPT, msg_history=None, api=AsyncOpenAICallable, - input_schema=None, + prompt_schema=None, + instructions_schema=None, output_schema=OUTPUT_SCHEMA, guard_state={}, num_reasks=0, diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index b8f3614f8..163891f9a 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -620,3 +620,17 @@ class Pet(BaseModel): ) else: assert validated_output == expected_result + + +def test_xml_input_validation(): + rail_str = """ + + +This is not two words + + + +""" + guard = Guard.from_rail_string(rail_str) + with pytest.raises(ValueError): + guard.parse("") From 6a3ddaa43fa16e3f0b63a9b24cbe7232a0360438 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Mon, 4 Dec 2023 11:40:03 +0000 Subject: [PATCH 04/11] raise on msg_history/prompt validation mismatch --- guardrails/run.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/guardrails/run.py b/guardrails/run.py index 619ae44f0..f1e7bb2f7 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -198,6 +198,7 @@ def step( prompt_params: Dict, prompt_schema: Optional[StringSchema], instructions_schema: Optional[StringSchema], + msg_history_schema: Optional[StringSchema], output_schema: Schema, output: 
Optional[str] = None, ): @@ -212,7 +213,7 @@ def step( prompt_params=prompt_params, prompt_schema=prompt_schema, instructions_schema=instructions_schema, - msg_history_schema=self.msg_history_schema, + msg_history_schema=msg_history_schema, output_schema=output_schema, ): # Prepare: run pre-processing, and input validation. @@ -231,6 +232,7 @@ def step( api, prompt_schema, instructions_schema, + msg_history_schema, output_schema, ) @@ -305,6 +307,11 @@ def prepare( prompt_params = {} if msg_history: + if prompt_schema is not None or instructions_schema is not None: + raise ValueError( + "Prompt and instructions validation are " + "not supported when using message history." + ) msg_history = copy.deepcopy(msg_history) # Format any variables in the message history with the prompt params. for msg in msg_history: @@ -325,6 +332,11 @@ def prepare( ) msg_history = validated_msg_history elif prompt is not None: + if msg_history_schema is not None: + raise ValueError( + "Message history validation is " + "not supported when using prompt/instructions." + ) if isinstance(prompt, str): prompt = Prompt(prompt) From 01a9ccad22146e98b94ddcb87f5549780bf5912c Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Mon, 4 Dec 2023 12:18:25 +0000 Subject: [PATCH 05/11] amend input validation call logging --- guardrails/run.py | 131 ++++++++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 57 deletions(-) diff --git a/guardrails/run.py b/guardrails/run.py index 08ce7909d..1cf9fd8b4 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -207,25 +207,6 @@ def step( output: Optional[str] = None, ) -> Iteration: """Run a full step.""" - inputs = Inputs( - llm_api=api, - llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - prompt_schema=prompt_schema, - instructions_schema=instructions_schema, - msg_history_schema=msg_history_schema, - output_schema=output_schema, - num_reasks=self.num_reasks, - metadata=self.metadata, - full_schema_reask=self.full_schema_reask, - ) - outputs = Outputs() - iteration = Iteration(inputs=inputs, outputs=outputs) - call_log.iterations.push(iteration) - try: with start_action( action_type="step", @@ -233,7 +214,9 @@ def step( instructions=instructions, prompt=prompt, prompt_params=prompt_params, - input_schema=input_schema, + prompt_schema=prompt_schema, + instructions_schema=instructions_schema, + msg_history_schema=msg_history_schema, output_schema=output_schema, ): # Prepare: run pre-processing, and input validation. @@ -243,7 +226,7 @@ def step( msg_history = None else: instructions, prompt, msg_history = self.prepare( - iteration, + call_log, index, instructions, prompt, @@ -256,9 +239,20 @@ def step( output_schema, ) - iteration.inputs.prompt = prompt - iteration.inputs.instructions = instructions - iteration.inputs.msg_history = msg_history + inputs = Inputs( + llm_api=api, + llm_output=output, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=self.metadata, + full_schema_reask=self.full_schema_reask, + ) + outputs = Outputs() + iteration = Iteration(inputs=inputs, outputs=outputs) + call_log.iterations.push(iteration) # Call: run the API. 
llm_response = self.call( @@ -303,7 +297,7 @@ def step( def prepare( self, - iteration: Iteration, + call_log: Call, index: int, instructions: Optional[Instructions], prompt: Optional[Prompt], @@ -342,16 +336,21 @@ def prepare( # validate msg_history if msg_history_schema is not None: + msg_str = msg_history_string(msg_history) + inputs = Inputs( + llm_output=msg_str, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_msg_history = msg_history_schema.validate( - iteration, msg_history_string(msg_history), self.metadata + iteration, msg_str, self.metadata ) - if validated_msg_history is None: - raise ValidatorError("Message history validation failed") if isinstance(validated_msg_history, ReAsk): raise ValidatorError( f"Message history validation failed: {validated_msg_history}" ) - msg_history = validated_msg_history + if validated_msg_history != msg_str: + raise ValidatorError("Message history validation failed") elif prompt is not None: if msg_history_schema is not None: raise ValueError( @@ -374,6 +373,11 @@ def prepare( # validate prompt if prompt_schema is not None: + inputs = Inputs( + llm_output=prompt.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_prompt = prompt_schema.validate( iteration, prompt.source, self.metadata ) @@ -387,6 +391,11 @@ def prepare( # validate instructions if instructions_schema is not None: + inputs = Inputs( + llm_output=instructions.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_instructions = instructions_schema.validate( iteration, instructions.source, self.metadata ) @@ -686,24 +695,6 @@ async def async_step( output: Optional[str] = None, ) -> Iteration: """Run a full step.""" - inputs = Inputs( - llm_api=api, - llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - prompt_schema=prompt_schema, - instructions_schema=instructions_schema, - msg_history_schema=self.msg_history_schema, - output_schema=output_schema, - num_reasks=self.num_reasks, - metadata=self.metadata, - full_schema_reask=self.full_schema_reask, - ) - outputs = Outputs() - iteration = Iteration(inputs=inputs, outputs=outputs) - call_log.iterations.push(iteration) try: with start_action( action_type="step", @@ -722,8 +713,8 @@ async def async_step( prompt = None msg_history = None else: - instructions, prompt, msg_history = self.prepare( - iteration, + instructions, prompt, msg_history = await self.async_prepare( + call_log, index, instructions, prompt, @@ -736,9 +727,20 @@ async def async_step( output_schema, ) - iteration.inputs.prompt = prompt - iteration.inputs.instructions = instructions - iteration.inputs.msg_history = msg_history + inputs = Inputs( + llm_api=api, + llm_output=output, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=self.metadata, + full_schema_reask=self.full_schema_reask, + ) + outputs = Outputs() + iteration = Iteration(inputs=inputs, outputs=outputs) + call_log.iterations.push(iteration) # Call: run the API. 
llm_response = await self.async_call( @@ -859,7 +861,7 @@ async def async_validate( async def async_prepare( self, - iteration: Iteration, + call_log: Call, index: int, instructions: Optional[Instructions], prompt: Optional[Prompt], @@ -893,16 +895,21 @@ async def async_prepare( # validate msg_history if msg_history_schema is not None: + msg_str = msg_history_string(msg_history) + inputs = Inputs( + llm_output=msg_str, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_msg_history = await msg_history_schema.async_validate( - iteration, msg_history_string(msg_history), self.metadata + iteration, msg_str, self.metadata ) - if validated_msg_history is None: - raise ValidatorError("Message history validation failed") if isinstance(validated_msg_history, ReAsk): raise ValidatorError( f"Message history validation failed: {validated_msg_history}" ) - msg_history = validated_msg_history + if validated_msg_history != msg_str: + raise ValidatorError("Message history validation failed") elif prompt is not None: if isinstance(prompt, str): prompt = Prompt(prompt) @@ -920,6 +927,11 @@ async def async_prepare( # validate prompt if prompt_schema is not None: + inputs = Inputs( + llm_output=prompt.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_prompt = await prompt_schema.async_validate( iteration, prompt.source, self.metadata ) @@ -933,6 +945,11 @@ async def async_prepare( # validate instructions if instructions_schema is not None: + inputs = Inputs( + llm_output=instructions.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.push(iteration) validated_instructions = await instructions_schema.async_validate( iteration, instructions.source, self.metadata ) From 3524f2dfc4071f202852669a3d8bc430ce9056b6 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Mon, 4 Dec 2023 15:21:40 +0000 Subject: [PATCH 06/11] tests, format --- guardrails/guard.py | 21 +-- guardrails/rail.py | 12 +- guardrails/run.py | 96 +++++++----- tests/unit_tests/test_validators.py | 229 +++++++++++++++++++++++++++- 4 files changed, 298 insertions(+), 60 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index 6d01b5fa4..a52658c17 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -28,9 +28,8 @@ from guardrails.prompt import Instructions, Prompt from guardrails.rail import Rail from guardrails.run import AsyncRunner, Runner -from guardrails.schema import StringSchema +from guardrails.schema import Schema, StringSchema from guardrails.utils.reask_utils import sub_reasks_with_fixed_values -from guardrails.schema import Schema from guardrails.validators import Validator logger = logging.getLogger(__name__) @@ -68,17 +67,17 @@ def __init__( self.base_model = base_model @property - def prompt_schema(self) -> Optional[Schema]: + def prompt_schema(self) -> Optional[StringSchema]: """Return the input schema.""" return self.rail.prompt_schema @property - def instructions_schema(self) -> Optional[Schema]: + def instructions_schema(self) -> Optional[StringSchema]: """Return the input schema.""" return self.rail.instructions_schema @property - def msg_history_schema(self) -> Optional[Schema]: + def msg_history_schema(self) -> Optional[StringSchema]: """Return the input schema.""" return self.rail.msg_history_schema @@ -694,9 +693,7 @@ def with_prompt_validation( validators: The validators to add to the prompt. """ if self.rail.prompt_schema: - warnings.warn( - "Overriding existing prompt validators." 
- ) + warnings.warn("Overriding existing prompt validators.") schema = StringSchema.from_string( validators=validators, ) @@ -713,9 +710,7 @@ def with_instructions_validation( validators: The validators to add to the instructions. """ if self.rail.instructions_schema: - warnings.warn( - "Overriding existing instructions validators." - ) + warnings.warn("Overriding existing instructions validators.") schema = StringSchema.from_string( validators=validators, ) @@ -732,9 +727,7 @@ def with_msg_history_validation( validators: The validators to add to the msg_history. """ if self.rail.msg_history_schema: - warnings.warn( - "Overriding existing msg_history validators." - ) + warnings.warn("Overriding existing msg_history validators.") schema = StringSchema.from_string( validators=validators, ) diff --git a/guardrails/rail.py b/guardrails/rail.py index 71f4d8ae4..4e32f6909 100644 --- a/guardrails/rail.py +++ b/guardrails/rail.py @@ -29,9 +29,9 @@ class Rail: 4. ``, which contains the instructions to be passed to the LLM """ - prompt_schema: Optional[Schema] - instructions_schema: Optional[Schema] - msg_history_schema: Optional[Schema] + prompt_schema: Optional[StringSchema] + instructions_schema: Optional[StringSchema] + msg_history_schema: Optional[StringSchema] output_schema: Schema instructions: Optional[Instructions] prompt: Optional[Prompt] @@ -168,10 +168,12 @@ def from_string_validators( @staticmethod def load_input_schema_from_xml( - root: Optional[ET._Element] + root: Optional[ET._Element], ) -> Optional[StringSchema]: """Given the RAIL element, create a Schema object.""" - if root is None: + if root is None or all( + tag not in root.attrib for tag in ["format", "validators"] + ): return None return StringSchema.from_xml(root) diff --git a/guardrails/run.py b/guardrails/run.py index 1cf9fd8b4..c1edbef00 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -19,7 +19,6 @@ sub_reasks_with_fixed_values, ) from guardrails.validator_base import ValidatorError -from guardrails.utils.reask_utils import NonParseableReAsk, ReAsk, reasks_to_dict logger = logging.getLogger(__name__) actions_logger = logging.getLogger(f"{__name__}.actions") @@ -143,7 +142,15 @@ def __call__( num_reasks=self.num_reasks, metadata=self.metadata, ): - instructions, prompt, msg_history, prompt_schema, instructions_schema, msg_history_schema, output_schema = ( + ( + instructions, + prompt, + msg_history, + prompt_schema, + instructions_schema, + msg_history_schema, + output_schema, + ) = ( self.instructions, self.prompt, self.msg_history, @@ -186,7 +193,9 @@ def __call__( prompt_params=prompt_params, include_instructions=include_instructions, ) - + # TODO decide how to handle errors + except (ValidatorError, ValueError) as e: + raise e except Exception as e: error_message = str(e) return call_log, error_message @@ -207,6 +216,19 @@ def step( output: Optional[str] = None, ) -> Iteration: """Run a full step.""" + inputs = Inputs( + llm_api=api, + llm_output=output, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=self.metadata, + full_schema_reask=self.full_schema_reask, + ) + outputs = Outputs() + iteration = Iteration(inputs=inputs, outputs=outputs) try: with start_action( action_type="step", @@ -239,19 +261,10 @@ def step( output_schema, ) - inputs = Inputs( - llm_api=api, - llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - 
num_reasks=self.num_reasks, - metadata=self.metadata, - full_schema_reask=self.full_schema_reask, - ) - outputs = Outputs() - iteration = Iteration(inputs=inputs, outputs=outputs) + iteration.inputs.instructions = instructions + iteration.inputs.prompt = prompt + iteration.inputs.msg_history = msg_history + call_log.iterations.push(iteration) # Call: run the API. @@ -345,6 +358,7 @@ def prepare( validated_msg_history = msg_history_schema.validate( iteration, msg_str, self.metadata ) + iteration.outputs.validation_output = validated_msg_history if isinstance(validated_msg_history, ReAsk): raise ValidatorError( f"Message history validation failed: {validated_msg_history}" @@ -372,7 +386,7 @@ def prepare( ) # validate prompt - if prompt_schema is not None: + if prompt_schema is not None and prompt is not None: inputs = Inputs( llm_output=prompt.source, ) @@ -381,6 +395,7 @@ def prepare( validated_prompt = prompt_schema.validate( iteration, prompt.source, self.metadata ) + iteration.outputs.validation_output = validated_prompt if validated_prompt is None: raise ValidatorError("Prompt validation failed") if isinstance(validated_prompt, ReAsk): @@ -390,7 +405,7 @@ def prepare( prompt = Prompt(validated_prompt) # validate instructions - if instructions_schema is not None: + if instructions_schema is not None and instructions is not None: inputs = Inputs( llm_output=instructions.source, ) @@ -399,6 +414,7 @@ def prepare( validated_instructions = instructions_schema.validate( iteration, instructions.source, self.metadata ) + iteration.outputs.validation_output = validated_instructions if validated_instructions is None: raise ValidatorError("Instructions validation failed") if isinstance(validated_instructions, ReAsk): @@ -632,7 +648,15 @@ async def async_run( num_reasks=self.num_reasks, metadata=self.metadata, ): - instructions, prompt, msg_history, prompt_schema, instructions_schema, msg_history_schema, output_schema = ( + ( + instructions, + prompt, + msg_history, + prompt_schema, + instructions_schema, + msg_history_schema, + output_schema, + ) = ( self.instructions, self.prompt, self.msg_history, @@ -695,6 +719,19 @@ async def async_step( output: Optional[str] = None, ) -> Iteration: """Run a full step.""" + inputs = Inputs( + llm_api=api, + llm_output=output, + instructions=instructions, + prompt=prompt, + msg_history=msg_history, + prompt_params=prompt_params, + num_reasks=self.num_reasks, + metadata=self.metadata, + full_schema_reask=self.full_schema_reask, + ) + outputs = Outputs() + iteration = Iteration(inputs=inputs, outputs=outputs) try: with start_action( action_type="step", @@ -727,19 +764,10 @@ async def async_step( output_schema, ) - inputs = Inputs( - llm_api=api, - llm_output=output, - instructions=instructions, - prompt=prompt, - msg_history=msg_history, - prompt_params=prompt_params, - num_reasks=self.num_reasks, - metadata=self.metadata, - full_schema_reask=self.full_schema_reask, - ) - outputs = Outputs() - iteration = Iteration(inputs=inputs, outputs=outputs) + iteration.inputs.instructions = instructions + iteration.inputs.prompt = prompt + iteration.inputs.msg_history = msg_history + call_log.iterations.push(iteration) # Call: run the API. 
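Since every input-validation pass is now logged as its own Iteration on the call log, its outcome can be read back from guard.history after the run. A sketch of what the amended tests assert, here for instructions validation, again reusing the names from the earlier sketches:

    guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation(
        validators=[TwoWords(on_fail="fix")]
    )
    guard(
        get_static_openai_create_func(),
        prompt="What kind of pet should I get and what should I name it?",
        instructions="But really, what kind of pet should I get?",
    )
    # the fixed instructions are recorded on the iteration that was
    # added to the call log for input validation
    assert (
        guard.history.first.iterations.first.outputs.validation_output == "But really,"
    )
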
@@ -926,7 +954,7 @@ async def async_prepare( ) # validate prompt - if prompt_schema is not None: + if prompt_schema is not None and prompt is not None: inputs = Inputs( llm_output=prompt.source, ) @@ -944,7 +972,7 @@ async def async_prepare( prompt = Prompt(validated_prompt) # validate instructions - if instructions_schema is not None: + if instructions_schema is not None and instructions is not None: inputs = Inputs( llm_output=instructions.source, ) diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index c96833c34..4a8a08832 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -9,7 +9,7 @@ from guardrails import Guard from guardrails.datatypes import DataType from guardrails.schema import StringSchema -from guardrails.utils.openai_utils import OPENAI_VERSION +from guardrails.utils.openai_utils import OPENAI_VERSION, get_static_openai_create_func from guardrails.utils.reask_utils import FieldReAsk from guardrails.validator_base import ( FailResult, @@ -623,15 +623,230 @@ class Pet(BaseModel): assert validated_output == expected_result -def test_xml_input_validation(): - rail_str = """ +class Pet(BaseModel): + name: str = Field(description="a unique pet name") + + +def test_input_validation_fix(mocker): + if OPENAI_VERSION.startswith("0"): + mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion) + else: + mocker.patch( + "openai.resources.chat.completions.Completions.create", + new=mock_chat_completion, + ) + + # fix returns an amended value for prompt/instructions validation, + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail="fix")] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get?", + ) + assert guard.history.first.iterations.first.outputs.validation_output == "What kind" + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail="fix")] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get and what should I name it?", + instructions="But really, what kind of pet should I get?", + ) + assert ( + guard.history.first.iterations.first.outputs.validation_output == "But really," + ) + + # but raises for msg_history validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail="fix")] + ) + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + + # rail prompt validation + guard = Guard.from_rail_string( + f""" - + This is not two words -""" - guard = Guard.from_rail_string(rail_str) + +""" + ) + guard( + get_static_openai_create_func(), + ) + assert guard.history.first.iterations.first.outputs.validation_output == "This is" + + # rail instructions validation + guard = Guard.from_rail_string( + f""" + + +This is not two words + + +This also is not two words + + + + +""" + ) + guard( + get_static_openai_create_func(), + ) + assert guard.history.first.iterations.first.outputs.validation_output == "This also" + + +@pytest.mark.parametrize( + "on_fail", + [ + "reask", + "filter", + "refrain", + "exception", + ], +) +def test_input_validation_fail(mocker, on_fail): + if OPENAI_VERSION.startswith("0"): + mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion) + else: + mocker.patch( + "openai.resources.chat.completions.Completions.create", + 
new=mock_chat_completion, + ) + + # with_prompt_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get?", + ) + # with_instructions_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get and what should I name it?", + instructions="What kind of pet should I get?", + ) + # with_msg_history_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + # rail prompt validation + guard = Guard.from_rail_string( + f""" + + +This is not two words + + + + +""" + ) + with pytest.raises(ValidatorError): + guard( + get_static_openai_create_func(), + ) + # rail instructions validation + guard = Guard.from_rail_string( + f""" + + +This is not two words + + +This also is not two words + + + + +""" + ) with pytest.raises(ValidatorError): - guard.parse("") + guard( + get_static_openai_create_func(), + ) + + +def test_input_validation_mismatch_raise(): + # prompt validation, msg_history argument + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail="fix")] + ) + with pytest.raises(ValueError): + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + + # instructions validation, msg_history argument + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail="fix")] + ) + with pytest.raises(ValueError): + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + + # msg_history validation, prompt argument + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail="fix")] + ) + with pytest.raises(ValueError): + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get?", + ) From a0a693e63b3726aa47bc193776679dc64bd1cd55 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Mon, 4 Dec 2023 19:25:21 +0000 Subject: [PATCH 07/11] add async tests --- guardrails/guard.py | 1 - guardrails/run.py | 18 +-- tests/unit_tests/test_validators.py | 203 +++++++++++++++++++++++++--- 3 files changed, 190 insertions(+), 32 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index a52658c17..c36839535 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -29,7 +29,6 @@ from guardrails.rail import Rail from guardrails.run import AsyncRunner, Runner from guardrails.schema import Schema, StringSchema -from guardrails.utils.reask_utils import sub_reasks_with_fixed_values from guardrails.validators import Validator logger = logging.getLogger(__name__) diff --git a/guardrails/run.py b/guardrails/run.py index c1edbef00..7a47fdadb 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -11,13 +11,7 @@ from guardrails.prompt import Instructions, Prompt from guardrails.schema import Schema, StringSchema from guardrails.utils.llm_response 
import LLMResponse -from guardrails.utils.reask_utils import ( - FieldReAsk, - NonParseableReAsk, - ReAsk, - reasks_to_dict, - sub_reasks_with_fixed_values, -) +from guardrails.utils.reask_utils import NonParseableReAsk, ReAsk, reasks_to_dict from guardrails.validator_base import ValidatorError logger = logging.getLogger(__name__) @@ -361,7 +355,8 @@ def prepare( iteration.outputs.validation_output = validated_msg_history if isinstance(validated_msg_history, ReAsk): raise ValidatorError( - f"Message history validation failed: {validated_msg_history}" + f"Message history validation failed: " + f"{validated_msg_history}" ) if validated_msg_history != msg_str: raise ValidatorError("Message history validation failed") @@ -698,6 +693,8 @@ async def async_run( output_schema, prompt_params=prompt_params, ) + except (ValidatorError, ValueError) as e: + raise e except Exception as e: error_message = str(e) @@ -934,7 +931,8 @@ async def async_prepare( ) if isinstance(validated_msg_history, ReAsk): raise ValidatorError( - f"Message history validation failed: {validated_msg_history}" + f"Message history validation failed: " + f"{validated_msg_history}" ) if validated_msg_history != msg_str: raise ValidatorError("Message history validation failed") @@ -963,6 +961,7 @@ async def async_prepare( validated_prompt = await prompt_schema.async_validate( iteration, prompt.source, self.metadata ) + iteration.outputs.validation_output = validated_prompt if validated_prompt is None: raise ValidatorError("Prompt validation failed") if isinstance(validated_prompt, ReAsk): @@ -981,6 +980,7 @@ async def async_prepare( validated_instructions = await instructions_schema.async_validate( iteration, instructions.source, self.metadata ) + iteration.outputs.validation_output = validated_instructions if validated_instructions is None: raise ValidatorError("Instructions validation failed") if isinstance(validated_instructions, ReAsk): diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 4a8a08832..6d0429ffa 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -9,7 +9,11 @@ from guardrails import Guard from guardrails.datatypes import DataType from guardrails.schema import StringSchema -from guardrails.utils.openai_utils import OPENAI_VERSION, get_static_openai_create_func +from guardrails.utils.openai_utils import ( + OPENAI_VERSION, + get_static_openai_acreate_func, + get_static_openai_create_func, +) from guardrails.utils.reask_utils import FieldReAsk from guardrails.validator_base import ( FailResult, @@ -627,15 +631,7 @@ class Pet(BaseModel): name: str = Field(description="a unique pet name") -def test_input_validation_fix(mocker): - if OPENAI_VERSION.startswith("0"): - mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion) - else: - mocker.patch( - "openai.resources.chat.completions.Completions.create", - new=mock_chat_completion, - ) - +def test_input_validation_fix(): # fix returns an amended value for prompt/instructions validation, guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( validators=[TwoWords(on_fail="fix")] @@ -674,7 +670,7 @@ def test_input_validation_fix(mocker): # rail prompt validation guard = Guard.from_rail_string( - f""" + """ This is not two words @@ -716,6 +712,89 @@ def test_input_validation_fix(mocker): assert guard.history.first.iterations.first.outputs.validation_output == "This also" +@pytest.mark.asyncio +@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not 
supported in v1") +async def test_async_input_validation_fix(): + # fix returns an amended value for prompt/instructions validation, + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail="fix")] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get?", + ) + assert guard.history.first.iterations.first.outputs.validation_output == "What kind" + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail="fix")] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get and what should I name it?", + instructions="But really, what kind of pet should I get?", + ) + assert ( + guard.history.first.iterations.first.outputs.validation_output == "But really," + ) + + # but raises for msg_history validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail="fix")] + ) + await guard( + get_static_openai_acreate_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + + # rail prompt validation + guard = Guard.from_rail_string( + """ + + +This is not two words + + + + +""" + ) + await guard( + get_static_openai_acreate_func(), + ) + assert guard.history.first.iterations.first.outputs.validation_output == "This is" + + # rail instructions validation + guard = Guard.from_rail_string( + """ + + +This is not two words + + +This also is not two words + + + + +""" + ) + await guard( + get_static_openai_acreate_func(), + ) + assert guard.history.first.iterations.first.outputs.validation_output == "This also" + + @pytest.mark.parametrize( "on_fail", [ @@ -725,15 +804,7 @@ def test_input_validation_fix(mocker): "exception", ], ) -def test_input_validation_fail(mocker, on_fail): - if OPENAI_VERSION.startswith("0"): - mocker.patch("openai.ChatCompletion.create", new=mock_chat_completion) - else: - mocker.patch( - "openai.resources.chat.completions.Completions.create", - new=mock_chat_completion, - ) - +def test_input_validation_fail(on_fail): # with_prompt_validation with pytest.raises(ValidatorError): guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( @@ -771,7 +842,7 @@ def test_input_validation_fail(mocker, on_fail): guard = Guard.from_rail_string( f""" - @@ -810,6 +881,94 @@ def test_input_validation_fail(mocker, on_fail): ) +@pytest.mark.parametrize( + "on_fail", + [ + "reask", + "filter", + "refrain", + "exception", + ], +) +@pytest.mark.asyncio +@pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1") +async def test_input_validation_fail_async(mocker, on_fail): + # with_prompt_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get?", + ) + # with_instructions_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get and what should I name it?", + instructions="What kind of pet should I get?", + ) + # with_msg_history_validation + with pytest.raises(ValidatorError): + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + 
validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + # rail prompt validation + guard = Guard.from_rail_string( + f""" + + +This is not two words + + + + +""" + ) + with pytest.raises(ValidatorError): + await guard( + get_static_openai_acreate_func(), + ) + # rail instructions validation + guard = Guard.from_rail_string( + f""" + + +This is not two words + + +This also is not two words + + + + +""" + ) + with pytest.raises(ValidatorError): + await guard( + get_static_openai_acreate_func(), + ) + + def test_input_validation_mismatch_raise(): # prompt validation, msg_history argument guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( From 28355fade8db1c1b45a2deb9f06f1d793add35bc Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Tue, 5 Dec 2023 14:52:24 +0000 Subject: [PATCH 08/11] always push iteration to stack first --- guardrails/run.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/guardrails/run.py b/guardrails/run.py index 11c446cfc..6322457e6 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -222,6 +222,7 @@ def step( outputs = Outputs() iteration = Iteration(inputs=inputs, outputs=outputs) set_scope(str(id(iteration))) + call_log.iterations.push(iteration) try: with start_action( @@ -259,8 +260,6 @@ def step( iteration.inputs.prompt = prompt iteration.inputs.msg_history = msg_history - call_log.iterations.push(iteration) - # Call: run the API. llm_response = self.call( index, instructions, prompt, msg_history, api, output @@ -348,7 +347,7 @@ def prepare( llm_output=msg_str, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_msg_history = msg_history_schema.validate( iteration, msg_str, self.metadata ) @@ -386,7 +385,7 @@ def prepare( llm_output=prompt.source, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_prompt = prompt_schema.validate( iteration, prompt.source, self.metadata ) @@ -405,7 +404,7 @@ def prepare( llm_output=instructions.source, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_instructions = instructions_schema.validate( iteration, instructions.source, self.metadata ) @@ -729,6 +728,8 @@ async def async_step( ) outputs = Outputs() iteration = Iteration(inputs=inputs, outputs=outputs) + call_log.iterations.push(iteration) + try: with start_action( action_type="step", @@ -765,8 +766,6 @@ async def async_step( iteration.inputs.prompt = prompt iteration.inputs.msg_history = msg_history - call_log.iterations.push(iteration) - # Call: run the API. 
llm_response = await self.async_call( index, instructions, prompt, msg_history, api, output @@ -925,7 +924,7 @@ async def async_prepare( llm_output=msg_str, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_msg_history = await msg_history_schema.async_validate( iteration, msg_str, self.metadata ) @@ -957,7 +956,7 @@ async def async_prepare( llm_output=prompt.source, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_prompt = await prompt_schema.async_validate( iteration, prompt.source, self.metadata ) @@ -976,7 +975,7 @@ async def async_prepare( llm_output=instructions.source, ) iteration = Iteration(inputs=inputs) - call_log.iterations.push(iteration) + call_log.iterations.insert(0, iteration) validated_instructions = await instructions_schema.async_validate( iteration, instructions.source, self.metadata ) From b437d8587d6c5c748233a14602940091f150ec81 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Tue, 5 Dec 2023 15:43:15 +0000 Subject: [PATCH 09/11] wrap user facing exceptions --- guardrails/classes/history/call.py | 7 + guardrails/classes/history/iteration.py | 5 + guardrails/classes/history/outputs.py | 3 + guardrails/run.py | 36 ++-- guardrails/utils/exception_utils.py | 9 + tests/integration_tests/test_run.py | 4 + tests/unit_tests/test_validators.py | 216 ++++++++++++------------ 7 files changed, 156 insertions(+), 124 deletions(-) create mode 100644 guardrails/utils/exception_utils.py diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py index 0e75e613d..b60d337cf 100644 --- a/guardrails/classes/history/call.py +++ b/guardrails/classes/history/call.py @@ -281,6 +281,13 @@ def error(self) -> Optional[str]: return None return self.iterations.last.error # type: ignore + @property + def exception(self) -> Optional[Exception]: + """The exception that interrupted the run.""" + if self.iterations.empty(): + return None + return self.iterations.last.exception # type: ignore + @property def failed_validations(self) -> Stack[ValidatorLogs]: """The validator logs for any validations that failed during the diff --git a/guardrails/classes/history/iteration.py b/guardrails/classes/history/iteration.py index 3d29bfe71..5aa5a346d 100644 --- a/guardrails/classes/history/iteration.py +++ b/guardrails/classes/history/iteration.py @@ -111,6 +111,11 @@ def error(self) -> Optional[str]: this iteration.""" return self.outputs.error + @property + def exception(self) -> Optional[Exception]: + """The exception that interrupted this iteration.""" + return self.outputs.exception + @property def failed_validations(self) -> List[ValidatorLogs]: """The validator logs for any validations that failed during this diff --git a/guardrails/classes/history/outputs.py b/guardrails/classes/history/outputs.py index 8ca033ed8..56890cc99 100644 --- a/guardrails/classes/history/outputs.py +++ b/guardrails/classes/history/outputs.py @@ -43,6 +43,9 @@ class Outputs(ArbitraryModel): "that raised and interrupted the process.", default=None, ) + exception: Optional[Exception] = Field( + description="The exception that interrupted the process.", default=None + ) def _all_empty(self) -> bool: return ( diff --git a/guardrails/run.py b/guardrails/run.py index 6322457e6..f3b4fd08c 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -10,6 +10,7 @@ from guardrails.logger import logger, set_scope from guardrails.prompt import Instructions, 
Prompt from guardrails.schema import Schema, StringSchema +from guardrails.utils.exception_utils import UserFacingException from guardrails.utils.llm_response import LLMResponse from guardrails.utils.reask_utils import NonParseableReAsk, ReAsk, reasks_to_dict from guardrails.validator_base import ValidatorError @@ -185,9 +186,8 @@ def __call__( prompt_params=prompt_params, include_instructions=include_instructions, ) - # TODO decide how to handle errors - except (ValidatorError, ValueError) as e: - raise e + except UserFacingException as e: + raise e.original_exception except Exception as e: error_message = str(e) return call_log, error_message @@ -273,6 +273,7 @@ def step( index, raw_output, output_schema ) if parsing_error: + iteration.outputs.exception = parsing_error iteration.outputs.error = str(parsing_error) iteration.outputs.parsed_output = parsed_output @@ -298,6 +299,7 @@ def step( except Exception as e: error_message = str(e) iteration.outputs.error = error_message + iteration.outputs.exception = e raise e return iteration @@ -322,16 +324,18 @@ def prepare( """ with start_action(action_type="prepare", index=index) as action: if api is None: - raise ValueError("API must be provided.") + raise UserFacingException(ValueError("API must be provided.")) if prompt_params is None: prompt_params = {} if msg_history: if prompt_schema is not None or instructions_schema is not None: - raise ValueError( - "Prompt and instructions validation are " - "not supported when using message history." + raise UserFacingException( + ValueError( + "Prompt and instructions validation are " + "not supported when using message history." + ) ) msg_history = copy.deepcopy(msg_history) # Format any variables in the message history with the prompt params. @@ -361,9 +365,11 @@ def prepare( raise ValidatorError("Message history validation failed") elif prompt is not None: if msg_history_schema is not None: - raise ValueError( - "Message history validation is " - "not supported when using prompt/instructions." + raise UserFacingException( + ValueError( + "Message history validation is " + "not supported when using prompt/instructions." + ) ) if isinstance(prompt, str): prompt = Prompt(prompt) @@ -417,7 +423,9 @@ def prepare( ) instructions = Instructions(validated_instructions) else: - raise ValueError("Prompt or message history must be provided.") + raise UserFacingException( + ValueError("Prompt or message history must be provided.") + ) action.log( message_type="info", @@ -692,8 +700,8 @@ async def async_run( output_schema, prompt_params=prompt_params, ) - except (ValidatorError, ValueError) as e: - raise e + except UserFacingException as e: + raise e.original_exception except Exception as e: error_message = str(e) @@ -777,6 +785,7 @@ async def async_step( # Parse: parse the output. parsed_output, parsing_error = self.parse(index, output, output_schema) if parsing_error: + iteration.outputs.exception = parsing_error iteration.outputs.error = str(parsing_error) iteration.outputs.parsed_output = parsed_output @@ -801,6 +810,7 @@ async def async_step( except Exception as e: error_message = str(e) iteration.outputs.error = error_message + iteration.outputs.exception = e raise e return iteration diff --git a/guardrails/utils/exception_utils.py b/guardrails/utils/exception_utils.py new file mode 100644 index 000000000..c11c8b8ac --- /dev/null +++ b/guardrails/utils/exception_utils.py @@ -0,0 +1,9 @@ +class UserFacingException(Exception): + """Wraps an exception to denote it as user-facing. 
+ + It will be unwrapped in runner. + """ + + def __init__(self, original_exception: Exception): + super().__init__() + self.original_exception = original_exception diff --git a/tests/integration_tests/test_run.py b/tests/integration_tests/test_run.py index cf1e302b1..d24acd358 100644 --- a/tests/integration_tests/test_run.py +++ b/tests/integration_tests/test_run.py @@ -136,6 +136,8 @@ async def test_sync_async_step_equivalence(mocker): None, {}, None, + None, + None, OUTPUT_SCHEMA, call_log, OUTPUT, @@ -150,6 +152,8 @@ async def test_sync_async_step_equivalence(mocker): None, {}, None, + None, + None, OUTPUT_SCHEMA, call_log, OUTPUT, diff --git a/tests/unit_tests/test_validators.py b/tests/unit_tests/test_validators.py index 6d0429ffa..918cea601 100644 --- a/tests/unit_tests/test_validators.py +++ b/tests/unit_tests/test_validators.py @@ -617,14 +617,7 @@ class Pet(BaseModel): elif isinstance(expected_result, FieldReAsk): assert guard.history.first.iterations.first.reasks[0] == expected_result else: - validated_output = guard.parse(output, num_reasks=0) - if isinstance(expected_result, FieldReAsk): - assert ( - guard.guard_state.all_histories[0].history[0].reasks[0] - == expected_result - ) - else: - assert validated_output == expected_result + assert response.validated_output == expected_result class Pet(BaseModel): @@ -654,19 +647,19 @@ def test_input_validation_fix(): ) # but raises for msg_history validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( - validators=[TwoWords(on_fail="fix")] - ) - guard( - get_static_openai_create_func(), - msg_history=[ - { - "role": "user", - "content": "What kind of pet should I get?", - } - ], - ) + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail="fix")] + ) + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + assert isinstance(guard.history.first.exception, ValidatorError) # rail prompt validation guard = Guard.from_rail_string( @@ -737,19 +730,19 @@ async def test_async_input_validation_fix(): ) # but raises for msg_history validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( - validators=[TwoWords(on_fail="fix")] - ) - await guard( - get_static_openai_acreate_func(), - msg_history=[ - { - "role": "user", - "content": "What kind of pet should I get?", - } - ], - ) + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail="fix")] + ) + await guard( + get_static_openai_acreate_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + assert isinstance(guard.history.first.exception, ValidatorError) # rail prompt validation guard = Guard.from_rail_string( @@ -806,38 +799,38 @@ async def test_async_input_validation_fix(): ) def test_input_validation_fail(on_fail): # with_prompt_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - guard( - get_static_openai_create_func(), - prompt="What kind of pet should I get?", - ) + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get?", + ) + assert 
isinstance(guard.history.last.exception, ValidatorError) # with_instructions_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - guard( - get_static_openai_create_func(), - prompt="What kind of pet should I get and what should I name it?", - instructions="What kind of pet should I get?", - ) + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + prompt="What kind of pet should I get and what should I name it?", + instructions="What kind of pet should I get?", + ) + assert isinstance(guard.history.last.exception, ValidatorError) # with_msg_history_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - guard( - get_static_openai_create_func(), - msg_history=[ - { - "role": "user", - "content": "What kind of pet should I get?", - } - ], - ) + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + guard( + get_static_openai_create_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + assert isinstance(guard.history.last.exception, ValidatorError) # rail prompt validation guard = Guard.from_rail_string( f""" @@ -853,10 +846,10 @@ def test_input_validation_fail(on_fail): """ ) - with pytest.raises(ValidatorError): - guard( - get_static_openai_create_func(), - ) + guard( + get_static_openai_create_func(), + ) + assert isinstance(guard.history.last.exception, ValidatorError) # rail instructions validation guard = Guard.from_rail_string( f""" @@ -875,10 +868,10 @@ def test_input_validation_fail(on_fail): """ ) - with pytest.raises(ValidatorError): - guard( - get_static_openai_create_func(), - ) + guard( + get_static_openai_create_func(), + ) + assert isinstance(guard.history.last.exception, ValidatorError) @pytest.mark.parametrize( @@ -892,40 +885,41 @@ def test_input_validation_fail(on_fail): ) @pytest.mark.asyncio @pytest.mark.skipif(not OPENAI_VERSION.startswith("0"), reason="Not supported in v1") -async def test_input_validation_fail_async(mocker, on_fail): +async def test_input_validation_fail_async(on_fail): # with_prompt_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - await guard( - get_static_openai_acreate_func(), - prompt="What kind of pet should I get?", - ) + guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get?", + ) + assert isinstance(guard.history.last.exception, ValidatorError) + # with_instructions_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - await guard( - get_static_openai_acreate_func(), - prompt="What kind of pet should I get and what should I name it?", - instructions="What kind of pet should I get?", - ) + guard = Guard.from_pydantic(output_class=Pet).with_instructions_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + prompt="What kind of pet should I get 
and what should I name it?", + instructions="What kind of pet should I get?", + ) + assert isinstance(guard.history.last.exception, ValidatorError) # with_msg_history_validation - with pytest.raises(ValidatorError): - guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( - validators=[TwoWords(on_fail=on_fail)] - ) - await guard( - get_static_openai_acreate_func(), - msg_history=[ - { - "role": "user", - "content": "What kind of pet should I get?", - } - ], - ) + guard = Guard.from_pydantic(output_class=Pet).with_msg_history_validation( + validators=[TwoWords(on_fail=on_fail)] + ) + await guard( + get_static_openai_acreate_func(), + msg_history=[ + { + "role": "user", + "content": "What kind of pet should I get?", + } + ], + ) + assert isinstance(guard.history.last.exception, ValidatorError) # rail prompt validation guard = Guard.from_rail_string( f""" @@ -941,10 +935,10 @@ async def test_input_validation_fail_async(mocker, on_fail): """ ) - with pytest.raises(ValidatorError): - await guard( - get_static_openai_acreate_func(), - ) + await guard( + get_static_openai_acreate_func(), + ) + assert isinstance(guard.history.last.exception, ValidatorError) # rail instructions validation guard = Guard.from_rail_string( f""" @@ -963,10 +957,10 @@ async def test_input_validation_fail_async(mocker, on_fail): """ ) - with pytest.raises(ValidatorError): - await guard( - get_static_openai_acreate_func(), - ) + await guard( + get_static_openai_acreate_func(), + ) + assert isinstance(guard.history.last.exception, ValidatorError) def test_input_validation_mismatch_raise(): From b5d9f6eccb28aad96e365823562fc765d382110e Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Thu, 7 Dec 2023 14:02:52 +0000 Subject: [PATCH 10/11] run: Split long control flow into submethods --- guardrails/run.py | 210 +++++++++++++++++++++++++++++----------------- 1 file changed, 135 insertions(+), 75 deletions(-) diff --git a/guardrails/run.py b/guardrails/run.py index f3b4fd08c..5ffd70c17 100644 --- a/guardrails/run.py +++ b/guardrails/run.py @@ -303,6 +303,128 @@ def step( raise e return iteration + def validate_msg_history( + self, + call_log: Call, + msg_history: List[Dict], + msg_history_schema: StringSchema, + ): + msg_str = msg_history_string(msg_history) + inputs = Inputs( + llm_output=msg_str, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_msg_history = msg_history_schema.validate( + iteration, msg_str, self.metadata + ) + iteration.outputs.validation_output = validated_msg_history + if isinstance(validated_msg_history, ReAsk): + raise ValidatorError( + f"Message history validation failed: " f"{validated_msg_history}" + ) + if validated_msg_history != msg_str: + raise ValidatorError("Message history validation failed") + + def prepare_msg_history( + self, + call_log: Call, + msg_history: List[Dict], + prompt_params: Dict, + msg_history_schema: Optional[StringSchema], + ): + msg_history = copy.deepcopy(msg_history) + # Format any variables in the message history with the prompt params. 
+ for msg in msg_history: + msg["content"] = msg["content"].format(**prompt_params) + + # validate msg_history + if msg_history_schema is not None: + self.validate_msg_history(call_log, msg_history, msg_history_schema) + + return msg_history + + def validate_prompt( + self, + call_log: Call, + prompt_schema: StringSchema, + prompt: Prompt, + ): + inputs = Inputs( + llm_output=prompt.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_prompt = prompt_schema.validate( + iteration, prompt.source, self.metadata + ) + iteration.outputs.validation_output = validated_prompt + if validated_prompt is None: + raise ValidatorError("Prompt validation failed") + if isinstance(validated_prompt, ReAsk): + raise ValidatorError(f"Prompt validation failed: {validated_prompt}") + return Prompt(validated_prompt) + + def validate_instructions( + self, + call_log: Call, + instructions_schema: StringSchema, + instructions: Instructions, + ): + inputs = Inputs( + llm_output=instructions.source, + ) + iteration = Iteration(inputs=inputs) + call_log.iterations.insert(0, iteration) + validated_instructions = instructions_schema.validate( + iteration, instructions.source, self.metadata + ) + iteration.outputs.validation_output = validated_instructions + if validated_instructions is None: + raise ValidatorError("Instructions validation failed") + if isinstance(validated_instructions, ReAsk): + raise ValidatorError( + f"Instructions validation failed: {validated_instructions}" + ) + return Instructions(validated_instructions) + + def prepare_prompt( + self, + call_log: Call, + instructions: Optional[Instructions], + prompt: Prompt, + prompt_params: Dict, + api: Union[PromptCallableBase, AsyncPromptCallableBase], + prompt_schema: Optional[StringSchema], + instructions_schema: Optional[StringSchema], + output_schema: Schema, + ): + if isinstance(prompt, str): + prompt = Prompt(prompt) + + prompt = prompt.format(**prompt_params) + + # TODO(shreya): should there be any difference + # to parsing params for prompt? + if instructions is not None and isinstance(instructions, Instructions): + instructions = instructions.format(**prompt_params) + + instructions, prompt = output_schema.preprocess_prompt( + api, instructions, prompt + ) + + # validate prompt + if prompt_schema is not None and prompt is not None: + prompt = self.validate_prompt(call_log, prompt_schema, prompt) + + # validate instructions + if instructions_schema is not None and instructions is not None: + instructions = self.validate_instructions( + call_log, instructions_schema, instructions + ) + + return instructions, prompt + def prepare( self, call_log: Call, @@ -337,32 +459,10 @@ def prepare( "not supported when using message history." ) ) - msg_history = copy.deepcopy(msg_history) - # Format any variables in the message history with the prompt params. 
- for msg in msg_history: - msg["content"] = msg["content"].format(**prompt_params) - prompt, instructions = None, None - - # validate msg_history - if msg_history_schema is not None: - msg_str = msg_history_string(msg_history) - inputs = Inputs( - llm_output=msg_str, - ) - iteration = Iteration(inputs=inputs) - call_log.iterations.insert(0, iteration) - validated_msg_history = msg_history_schema.validate( - iteration, msg_str, self.metadata - ) - iteration.outputs.validation_output = validated_msg_history - if isinstance(validated_msg_history, ReAsk): - raise ValidatorError( - f"Message history validation failed: " - f"{validated_msg_history}" - ) - if validated_msg_history != msg_str: - raise ValidatorError("Message history validation failed") + msg_history = self.prepare_msg_history( + call_log, msg_history, prompt_params, msg_history_schema + ) elif prompt is not None: if msg_history_schema is not None: raise UserFacingException( @@ -371,57 +471,17 @@ def prepare( "not supported when using prompt/instructions." ) ) - if isinstance(prompt, str): - prompt = Prompt(prompt) - - prompt = prompt.format(**prompt_params) - - # TODO(shreya): should there be any difference - # to parsing params for prompt? - if instructions is not None and isinstance(instructions, Instructions): - instructions = instructions.format(**prompt_params) - - instructions, prompt = output_schema.preprocess_prompt( - api, instructions, prompt + msg_history = None + instructions, prompt = self.prepare_prompt( + call_log, + instructions, + prompt, + prompt_params, + api, + prompt_schema, + instructions_schema, + output_schema, ) - - # validate prompt - if prompt_schema is not None and prompt is not None: - inputs = Inputs( - llm_output=prompt.source, - ) - iteration = Iteration(inputs=inputs) - call_log.iterations.insert(0, iteration) - validated_prompt = prompt_schema.validate( - iteration, prompt.source, self.metadata - ) - iteration.outputs.validation_output = validated_prompt - if validated_prompt is None: - raise ValidatorError("Prompt validation failed") - if isinstance(validated_prompt, ReAsk): - raise ValidatorError( - f"Prompt validation failed: {validated_prompt}" - ) - prompt = Prompt(validated_prompt) - - # validate instructions - if instructions_schema is not None and instructions is not None: - inputs = Inputs( - llm_output=instructions.source, - ) - iteration = Iteration(inputs=inputs) - call_log.iterations.insert(0, iteration) - validated_instructions = instructions_schema.validate( - iteration, instructions.source, self.metadata - ) - iteration.outputs.validation_output = validated_instructions - if validated_instructions is None: - raise ValidatorError("Instructions validation failed") - if isinstance(validated_instructions, ReAsk): - raise ValidatorError( - f"Instructions validation failed: {validated_instructions}" - ) - instructions = Instructions(validated_instructions) else: raise UserFacingException( ValueError("Prompt or message history must be provided.") From 6d03e84bcfa469a47d9013631f80cebf9a5aa055 Mon Sep 17 00:00:00 2001 From: Rafael Irgolic Date: Thu, 7 Dec 2023 14:42:13 +0000 Subject: [PATCH 11/11] Add input validation notebook --- docs/examples/input_validation.ipynb | 123 +++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 docs/examples/input_validation.ipynb diff --git a/docs/examples/input_validation.ipynb b/docs/examples/input_validation.ipynb new file mode 100644 index 000000000..527168673 --- /dev/null +++ b/docs/examples/input_validation.ipynb @@ -0,0 
+1,123 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Input Validation\n", + "\n", + "Guardrails supports validating inputs (prompts, instructions, msg_history) with string validators." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In XML, specify the validators on the `prompt` or `instructions` tag, as such:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "is_executing": true + }, + "outputs": [], + "source": [ + "rail_spec = \"\"\"\n", + "\n", + "\n", + "This is not two words\n", + "\n", + "\n", + "\n", + "\n", + "\"\"\"\n", + "\n", + "from guardrails import Guard\n", + "guard = Guard.from_rail_string(rail_spec)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When `fix` is specified as the on-fail handler, the prompt will automatically be amended before calling the LLM.\n", + "\n", + "In any other case (for example, `exception`), a `ValidationException` will be returned in the outcome." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "is_executing": true + }, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "outcome = guard(\n", + " openai.ChatCompletion.create,\n", + ")\n", + "outcome.error" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using pydantic to initialize a `Guard`, input validators can be specified by composition, as such:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from guardrails.validators import TwoWords\n", + "from pydantic import BaseModel\n", + "\n", + "\n", + "class Pet(BaseModel):\n", + " name: str\n", + " age: int\n", + "\n", + "\n", + "guard = Guard.from_pydantic(Pet)\n", + "guard.with_prompt_validation([TwoWords(on_fail=\"exception\")])\n", + "\n", + "outcome = guard(\n", + " openai.ChatCompletion.create,\n", + " prompt=\"This is not two words\",\n", + ")\n", + "outcome.error" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}
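
Usage sketch of the behaviour exercised by the updated unit tests in this series: once an input validator is attached, a failing prompt validation is captured on the call log rather than raised out of `guard(...)`. This assumes the OpenAI v0 `ChatCompletion` API used in the notebook above; the `Pet` model below is a trimmed-down stand-in for the one in the tests.

import openai
from pydantic import BaseModel

from guardrails import Guard
from guardrails.validators import TwoWords
from guardrails.validator_base import ValidatorError


class Pet(BaseModel):
    # Minimal stand-in for the test fixture's Pet model.
    name: str


# Attach a prompt validator; any on_fail action (here "exception") is handled the same way.
guard = Guard.from_pydantic(output_class=Pet).with_prompt_validation(
    validators=[TwoWords(on_fail="exception")]
)

guard(
    openai.ChatCompletion.create,
    prompt="This prompt is longer than two words",
)

# The ValidatorError no longer propagates to the caller; it is recorded
# on the last iteration of the call and surfaced via the history.
assert isinstance(guard.history.last.exception, ValidatorError)
print(guard.history.last.exception)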