Skip to content

Commit 991a8f2

Browse files
authored
Merge pull request #517 from guardrails-ai/skeleton-reask-engineering
Skeleton reask prompt engineering
2 parents 11442aa + b41de2e commit 991a8f2

File tree

6 files changed

+98
-5
lines changed

6 files changed

+98
-5
lines changed

guardrails/constants.xml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,12 @@ ${output_schema}
5252
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
5353
</json_suffix_without_examples>
5454

55+
<json_suffix_with_structure_example>
56+
${gr.json_suffix_without_examples}
57+
Here's an example of the structure:
58+
${json_example}
59+
</json_suffix_with_structure_example>
60+
5561
<complete_json_suffix>
5662
Given below is XML that describes the information to extract from this document and the tags to extract it into.
5763

guardrails/datatypes.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,9 @@ def __init__(
7171
self.description = description
7272
self.optional = optional
7373

74+
def get_example(self):
75+
raise NotImplementedError
76+
7477
@property
7578
def validators(self) -> TypedList:
7679
return self.validators_attr.validators
@@ -188,6 +191,9 @@ class String(ScalarType):
188191

189192
tag = "string"
190193

194+
def get_example(self):
195+
return "string"
196+
191197
def from_str(self, s: str) -> Optional[str]:
192198
"""Create a String from a string."""
193199
return to_string(s)
@@ -214,6 +220,9 @@ class Integer(ScalarType):
214220

215221
tag = "integer"
216222

223+
def get_example(self):
224+
return 1
225+
217226
def from_str(self, s: str) -> Optional[int]:
218227
"""Create an Integer from a string."""
219228
return to_int(s)
@@ -225,6 +234,9 @@ class Float(ScalarType):
225234

226235
tag = "float"
227236

237+
def get_example(self):
238+
return 1.5
239+
228240
def from_str(self, s: str) -> Optional[float]:
229241
"""Create a Float from a string."""
230242
return to_float(s)
@@ -236,6 +248,9 @@ class Boolean(ScalarType):
236248

237249
tag = "bool"
238250

251+
def get_example(self):
252+
return True
253+
239254
def from_str(self, s: Union[str, bool]) -> Optional[bool]:
240255
"""Create a Boolean from a string."""
241256
if s is None:
@@ -273,6 +288,9 @@ def __init__(
273288
super().__init__(children, validators_attr, optional, name, description)
274289
self.date_format = None
275290

291+
def get_example(self):
292+
return datetime.date.today()
293+
276294
def from_str(self, s: str) -> Optional[datetime.date]:
277295
"""Create a Date from a string."""
278296
if s is None:
@@ -312,6 +330,9 @@ def __init__(
312330
self.time_format = "%H:%M:%S"
313331
super().__init__(children, validators_attr, optional, name, description)
314332

333+
def get_example(self):
334+
return datetime.time()
335+
315336
def from_str(self, s: str) -> Optional[datetime.time]:
316337
"""Create a Time from a string."""
317338
if s is None:
@@ -340,6 +361,9 @@ def __init__(self, *args, **kwargs):
340361
super().__init__(*args, **kwargs)
341362
deprecate_type(type(self))
342363

364+
def get_example(self):
365+
return "hello@example.com"
366+
343367

344368
@deprecate_type
345369
@register_type("url")
@@ -352,6 +376,9 @@ def __init__(self, *args, **kwargs):
352376
super().__init__(*args, **kwargs)
353377
deprecate_type(type(self))
354378

379+
def get_example(self):
380+
return "https://example.com"
381+
355382

356383
@deprecate_type
357384
@register_type("pythoncode")
@@ -364,6 +391,9 @@ def __init__(self, *args, **kwargs):
364391
super().__init__(*args, **kwargs)
365392
deprecate_type(type(self))
366393

394+
def get_example(self):
395+
return "print('hello world')"
396+
367397

368398
@deprecate_type
369399
@register_type("sql")
@@ -376,13 +406,19 @@ def __init__(self, *args, **kwargs):
376406
super().__init__(*args, **kwargs)
377407
deprecate_type(type(self))
378408

409+
def get_example(self):
410+
return "SELECT * FROM table"
411+
379412

380413
@register_type("percentage")
381414
class Percentage(ScalarType):
382415
"""Element tag: `<percentage>`"""
383416

384417
tag = "percentage"
385418

419+
def get_example(self):
420+
return "20%"
421+
386422

387423
@register_type("enum")
388424
class Enum(ScalarType):
@@ -402,6 +438,9 @@ def __init__(
402438
super().__init__(children, validators_attr, optional, name, description)
403439
self.enum_values = enum_values
404440

441+
def get_example(self):
442+
return self.enum_values[0]
443+
405444
def from_str(self, s: str) -> Optional[str]:
406445
"""Create an Enum from a string."""
407446
if s is None:
@@ -434,6 +473,9 @@ class List(NonScalarType):
434473

435474
tag = "list"
436475

476+
def get_example(self):
477+
return [e.get_example() for e in self._children.values()]
478+
437479
def collect_validation(
438480
self,
439481
key: str,
@@ -476,6 +518,9 @@ class Object(NonScalarType):
476518

477519
tag = "object"
478520

521+
def get_example(self):
522+
return {k: v.get_example() for k, v in self._children.items()}
523+
479524
def collect_validation(
480525
self,
481526
key: str,
@@ -546,6 +591,14 @@ def __init__(
546591
super().__init__(children, validators_attr, optional, name, description)
547592
self.discriminator_key = discriminator_key
548593

594+
def get_example(self):
595+
first_discriminator = list(self._children.keys())[0]
596+
first_child = list(self._children.values())[0]
597+
return {
598+
self.discriminator_key: first_discriminator,
599+
**first_child.get_example(),
600+
}
601+
549602
@classmethod
550603
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
551604
# grab `discriminator` attribute
@@ -606,6 +659,9 @@ def __init__(
606659
) -> None:
607660
super().__init__(children, validators_attr, optional, name, description)
608661

662+
def get_example(self):
663+
return {k: v.get_example() for k, v in self._children.items()}
664+
609665
def collect_validation(
610666
self,
611667
key: str,

guardrails/schema.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ def check_valid_reask_prompt(self, reask_prompt: Optional[str]) -> None:
219219

220220

221221
class JsonSchema(Schema):
222-
reask_prompt_vars = {"previous_response", "output_schema"}
222+
reask_prompt_vars = {"previous_response", "output_schema", "json_example"}
223223

224224
def __init__(
225225
self,
@@ -269,7 +269,7 @@ def get_reask_setup(
269269
if reask_prompt_template is None:
270270
reask_prompt_template = Prompt(
271271
constants["high_level_skeleton_reask_prompt"]
272-
+ constants["json_suffix_without_examples"]
272+
+ constants["json_suffix_with_structure_example"]
273273
)
274274

275275
# This is incorrect
@@ -300,6 +300,10 @@ def get_reask_setup(
300300
)
301301

302302
pruned_tree_string = pruned_tree_schema.transpile()
303+
json_example = json.dumps(
304+
pruned_tree_schema.root_datatype.get_example(),
305+
indent=2,
306+
)
303307

304308
def reask_decoder(obj):
305309
decoded = {}
@@ -317,6 +321,7 @@ def reask_decoder(obj):
317321
reask_value, indent=2, default=reask_decoder, ensure_ascii=False
318322
),
319323
output_schema=pruned_tree_string,
324+
json_example=json_example,
320325
**(prompt_params or {}),
321326
)
322327

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ I was given the following JSON response, which had problems due to incorrect val
6969

7070
Help me correct the incorrect values based on the given error messages.
7171

72+
7273
Given below is XML that describes the information to extract from this document and the tags to extract it into.
7374

7475
<output>
@@ -85,6 +86,18 @@ Given below is XML that describes the information to extract from this document
8586

8687
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
8788

89+
Here's an example of the structure:
90+
{
91+
"fees": [
92+
{
93+
"name": "string",
94+
"explanation": "string",
95+
"value": 1.5
96+
}
97+
],
98+
"interest_rates": {}
99+
}
100+
88101

89102
Json Output:
90103

tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ I was given the following JSON response, which had problems due to incorrect val
1313

1414
Help me correct the incorrect values based on the given error messages.
1515

16+
1617
Given below is XML that describes the information to extract from this document and the tags to extract it into.
1718

1819
<output>
@@ -23,3 +24,10 @@ Given below is XML that describes the information to extract from this document
2324

2425

2526
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
27+
28+
Here's an example of the structure:
29+
{
30+
"name": "string",
31+
"director": "string",
32+
"release_year": 1
33+
}

tests/unit_tests/utils/test_reask_utils.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import pytest
44
from lxml import etree as ET
55

6-
from guardrails import Instructions, Prompt
76
from guardrails.classes.history.iteration import Iteration
87
from guardrails.datatypes import Object
98
from guardrails.schema import JsonSchema
@@ -443,10 +442,14 @@ def test_get_reask_prompt(
443442
444443
Help me correct the incorrect values based on the given error messages.
445444
445+
446446
Given below is XML that describes the information to extract from this document and the tags to extract it into.
447447
%s
448448
449449
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
450+
451+
Here's an example of the structure:
452+
%s
450453
""" # noqa: E501
451454
expected_instructions = """
452455
You are a helpful assistant only capable of communicating with valid JSON, and no other text.
@@ -467,13 +470,15 @@ def test_get_reask_prompt(
467470
result_prompt,
468471
instructions,
469472
) = output_schema.get_reask_setup(reasks, reask_json, False)
473+
json_example = output_schema.root_datatype.get_example()
470474

471-
assert result_prompt == Prompt(
475+
assert result_prompt.source == (
472476
expected_result_template
473477
% (
474478
json.dumps(reask_json, indent=2),
475479
expected_rail,
480+
json.dumps(json_example, indent=2),
476481
)
477482
)
478483

479-
assert instructions == Instructions(expected_instructions)
484+
assert instructions.source == expected_instructions

0 commit comments

Comments
 (0)