Skip to content

Commit e75d764

Browse files
committed
adding pydantic sampler for use with strict mode
1 parent 50f20a6 commit e75d764

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

backends/exllamav2/grammar.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
build_token_enforcer_tokenizer_data,
88
)
99
from loguru import logger
10-
from typing import List
10+
from typing import List, Type
1111
from functools import lru_cache
12+
from pydantic import BaseModel
1213

1314

1415
class OutlinesTokenizerWrapper:
@@ -106,6 +107,34 @@ def add_json_schema_filter(
106107
# Append the filters
107108
self.filters.extend([lmfilter, prefix_filter])
108109

110+
def add_pydantic_filter(
111+
self,
112+
pydantic_model: Type[BaseModel],
113+
model: ExLlamaV2,
114+
tokenizer:ExLlamaV2Tokenizer
115+
):
116+
"""Adds an ExllamaV2 filter based on a Pydantic model"""
117+
# Create the parser
118+
try:
119+
schema_parser = JsonSchemaParser(pydantic_model.model_json_schema())
120+
except Exception:
121+
traceback.print_exc()
122+
logger.error(
123+
"Skipping because the pydantic model couldn't be used. "
124+
"Please read the above error for more information."
125+
)
126+
127+
return
128+
129+
json_prefixes = ["[", "{"]
130+
131+
lmfilter = ExLlamaV2TokenEnforcerFilter(
132+
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
133+
)
134+
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
135+
136+
self.filters.extend([lmfilter, prefix_filter])
137+
109138
def add_regex_filter(
110139
self,
111140
pattern: str,

backends/exllamav2/model.py

+7
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,13 @@ async def generate_gen(
10771077
json_schema, self.model, self.tokenizer
10781078
)
10791079

1080+
# Add pydantic filter if it exists
1081+
pydantic_model = unwrap(kwargs.get("pydantic_model"))
1082+
if pydantic_model:
1083+
grammar_handler.add_pydantic_filter(
1084+
pydantic_model, self.model, self.tokenizer
1085+
)
1086+
10801087
# Add regex filter if it exists
10811088
regex_pattern = unwrap(kwargs.get("regex_pattern"))
10821089
if regex_pattern:

common/sampling.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from copy import deepcopy
66
from loguru import logger
77
from pydantic import AliasChoices, BaseModel, Field
8-
from typing import Dict, List, Optional, Union
8+
from typing import Dict, List, Optional, Union, Type
99

1010
from common.utils import unwrap, prune_dict
1111

@@ -176,6 +176,10 @@ class BaseSamplerRequest(BaseModel):
176176
default_factory=lambda: get_default_sampler_value("json_schema"),
177177
)
178178

179+
pydantic_model: Optional[Type[BaseModel]] = Field(
180+
default_factory=lambda: get_default_sampler_value("pydantic_model"),
181+
)
182+
179183
regex_pattern: Optional[str] = Field(
180184
default_factory=lambda: get_default_sampler_value("regex_pattern"),
181185
)
@@ -329,6 +333,7 @@ def to_gen_params(self, **kwargs):
329333
"cfg_scale": self.cfg_scale,
330334
"negative_prompt": self.negative_prompt,
331335
"json_schema": self.json_schema,
336+
"pydantic_model": self.pydantic_model,
332337
"regex_pattern": self.regex_pattern,
333338
"grammar_string": self.grammar_string,
334339
"speculative_ngram": self.speculative_ngram,

0 commit comments

Comments
 (0)