|
5 | 5 | from copy import deepcopy
|
6 | 6 | from loguru import logger
|
7 | 7 | from pydantic import AliasChoices, BaseModel, Field
|
8 |
| -from typing import Dict, List, Optional, Union |
| 8 | +from typing import Dict, List, Optional, Union, Type |
9 | 9 |
|
10 | 10 | from common.utils import unwrap, prune_dict
|
11 | 11 |
|
@@ -176,6 +176,10 @@ class BaseSamplerRequest(BaseModel):
|
176 | 176 | default_factory=lambda: get_default_sampler_value("json_schema"),
|
177 | 177 | )
|
178 | 178 |
|
| 179 | + pydantic_model: Optional[Type[BaseModel]] = Field( |
| 180 | + default_factory=lambda: get_default_sampler_value("pydantic_model"), |
| 181 | + ) |
| 182 | + |
179 | 183 | regex_pattern: Optional[str] = Field(
|
180 | 184 | default_factory=lambda: get_default_sampler_value("regex_pattern"),
|
181 | 185 | )
|
@@ -329,6 +333,7 @@ def to_gen_params(self, **kwargs):
|
329 | 333 | "cfg_scale": self.cfg_scale,
|
330 | 334 | "negative_prompt": self.negative_prompt,
|
331 | 335 | "json_schema": self.json_schema,
|
| 336 | + "pydantic_model": self.pydantic_model, |
332 | 337 | "regex_pattern": self.regex_pattern,
|
333 | 338 | "grammar_string": self.grammar_string,
|
334 | 339 | "speculative_ngram": self.speculative_ngram,
|
|
0 commit comments