Skip to content

Commit e3a9a73

Browse files
committed
remove duplication of guided backends + add new options
Signed-off-by: Jannis Schönleber <[email protected]>
1 parent c5da1c5 commit e3a9a73

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

vllm/engine/arg_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from vllm.transformers_utils.utils import check_gguf_file
2929
from vllm.usage.usage_lib import UsageContext
3030
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
31+
from vllm.v1.engine.processor import SUPPORTED_GUIDED_DECODING
3132

3233
if TYPE_CHECKING:
3334
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -1379,11 +1380,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
13791380
recommend_to_remove=False)
13801381
return False
13811382

1382-
# Xgrammar and Guidance are supported.
1383-
SUPPORTED_GUIDED_DECODING = [
1384-
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
1385-
"guidance:disable-any-whitespace", "auto"
1386-
]
13871383
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
13881384
_raise_or_fallback(feature_name="--guided-decoding-backend",
13891385
recommend_to_remove=False)

vllm/v1/engine/processor.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
from vllm.v1.structured_output.utils import (
2424
validate_structured_output_request_xgrammar)
2525

26+
# Guided decoding backends accepted in V1: xgrammar, guidance (with their option suffixes) and "auto".
27+
SUPPORTED_GUIDED_DECODING = [
28+
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
29+
"guidance:disable-any-whitespace", "guidance:no-additional-properties",
30+
"guidance:no-additional-properties,disable-any-whitespace",
31+
"guidance:disable-any-whitespace,no-additional-properties", "auto"
32+
]
33+
2634

2735
class Processor:
2836

@@ -120,14 +128,11 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
120128
if not params.guided_decoding or not self.decoding_config:
121129
return
122130

123-
supported_backends = [
124-
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
125-
"guidance:disable-any-whitespace", "auto"
126-
]
127131
engine_level_backend = self.decoding_config.guided_decoding_backend
128-
if engine_level_backend not in supported_backends:
129-
raise ValueError(f"Only {supported_backends} structured output is "
130-
"supported in V1.")
132+
if engine_level_backend not in SUPPORTED_GUIDED_DECODING:
133+
raise ValueError(
134+
f"Only {SUPPORTED_GUIDED_DECODING} structured output is "
135+
"supported in V1.")
131136
if params.guided_decoding.backend:
132137
if params.guided_decoding.backend != engine_level_backend:
133138
raise ValueError("Request-level structured output backend "

0 commit comments

Comments
 (0)