
Commit 26aed9c

hmellor authored and Mu Huai committed
Simplify (and fix) passing of guided decoding backend options (vllm-project#17008)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent f7f39e7 commit 26aed9c

17 files changed: +309 -217 lines changed


examples/online_serving/openai_chat_completion_structured_outputs.py

Lines changed: 4 additions & 3 deletions
@@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
 
 
     try:
-        # The no-fallback option forces vLLM to use xgrammar, so when it fails
-        # you get a 400 with the reason why
+        # The guided_decoding_disable_fallback option forces vLLM to use
+        # xgrammar, so when it fails you get a 400 with the reason why
         completion = client.chat.completions.create(
             model=model,
             messages=[{
@@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
             extra_body={
                 "guided_regex": r"\w+@\w+\.com\n",
                 "stop": ["\n"],
-                "guided_decoding_backend": "xgrammar:no-fallback"
+                "guided_decoding_backend": "xgrammar",
+                "guided_decoding_disable_fallback": True,
             },
         )
         return completion.choices[0].message.content
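After this change, client code selects the backend by name and toggles options with dedicated fields instead of a colon-separated suffix. A minimal sketch of the new request shape, assuming a vLLM OpenAI-compatible server at localhost:8000 and an illustrative model name:

from openai import OpenAI

# Assumed server address and model; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    messages=[{"role": "user", "content": "Give me an example email address."}],
    extra_body={
        "guided_regex": r"\w+@\w+\.com\n",
        "stop": ["\n"],
        # Backend name and its options are now separate fields.
        "guided_decoding_backend": "xgrammar",
        "guided_decoding_disable_fallback": True,
    },
)
print(completion.choices[0].message.content)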

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 125 additions & 81 deletions
Large diffs are not rendered by default.

tests/model_executor/test_guided_processors.py

Lines changed: 9 additions & 6 deletions
@@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
 
 def test_guided_decoding_backend_options():
     """Test backend-specific options"""
-    params = GuidedDecodingParams(
-        backend="xgrammar:option-1,option-2,option-3")
-    assert params.backend_options() == ["option-1", "option-2", "option-3"]
-
-    no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
-    assert no_fallback.no_fallback()
+    with pytest.warns(DeprecationWarning):
+        guided_decoding_params = GuidedDecodingParams(
+            backend=
+            "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties"
+        )
+    assert guided_decoding_params.backend == "xgrammar"
+    assert guided_decoding_params.disable_fallback
+    assert guided_decoding_params.disable_any_whitespace
+    assert guided_decoding_params.disable_additional_properties
 
 
 def test_pickle_xgrammar_tokenizer_data():
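The deprecation path is visible here: the old colon-separated option string still parses, but now emits a DeprecationWarning and is unpacked into the new boolean fields. A sketch of the new explicit form, using only constructor arguments that appear in this diff:

from vllm.sampling_params import GuidedDecodingParams

# New style: options are keyword arguments, not a "backend:opt1,opt2" string.
params = GuidedDecodingParams(
    backend="xgrammar",
    disable_fallback=True,
    disable_any_whitespace=True,
)
assert params.backend == "xgrammar"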

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 16 additions & 15 deletions
@@ -17,15 +17,12 @@
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
     #FIXME: This test is flaky on CI thus disabled
-    #("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
+    #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
 ]
 
 PARAMS_MODELS_TOKENIZER_MODE = [
@@ -73,6 +70,7 @@ def test_structured_output(
               enforce_eager=enforce_eager,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend,
+              guided_decoding_disable_any_whitespace=True,
               tokenizer_mode=tokenizer_mode)
 
     #
@@ -98,8 +96,7 @@
 
         generated_text = output.outputs[0].text
         assert generated_text is not None
-        if 'disable-any-whitespace' in guided_decoding_backend:
-            assert "\n" not in generated_text
+        assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -520,10 +517,11 @@ def test_structured_output_auto_mode(
 def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_V1", "1")
 
-    backend = 'guidance:no-additional-properties,disable-any-whitespace'
     llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
               max_model_len=1024,
-              guided_decoding_backend=backend)
+              guided_decoding_backend="guidance",
+              guided_decoding_disable_any_whitespace=True,
+              guided_decoding_disable_additional_properties=True)
 
     schema = {
         'type': 'object',
@@ -548,7 +546,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
               "<|im_end|>\n<|im_start|>assistant\n")
 
     def generate_with_backend(backend):
-        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        guided_params = GuidedDecodingParams(
+            json=schema,
+            backend=backend,
+            disable_any_whitespace=True,
+            disable_additional_properties=True)
         sampling_params = SamplingParams(temperature=0,
                                          max_tokens=256,
                                          guided_decoding=guided_params)
@@ -562,8 +564,7 @@ def generate_with_backend(backend):
         jsonschema.validate(instance=parsed_json, schema=schema)
         return parsed_json
 
-    generated = generate_with_backend(
-        'guidance:no-additional-properties,disable-any-whitespace')
+    generated = generate_with_backend("guidance")
     assert "a1" in generated
     assert "a2" in generated
     assert "a3" in generated

tests/v1/test_oracle.py

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ def test_unsupported_configs(monkeypatch):
     with pytest.raises(NotImplementedError):
         AsyncEngineArgs(
             model=MODEL,
-            guided_decoding_backend="lm-format-enforcer:no-fallback",
+            guided_decoding_backend="lm-format-enforcer",
+            guided_decoding_disable_fallback=True,
         ).create_engine_config()
 
     with pytest.raises(NotImplementedError):

vllm/config.py

Lines changed: 60 additions & 10 deletions
@@ -17,12 +17,14 @@
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union, get_args, get_origin)
+                    Optional, Protocol, TypeVar, Union, cast, get_args,
+                    get_origin)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
+from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@@ -32,7 +34,6 @@
     get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import CpuArchEnum, current_platform
-from vllm.sampling_params import GuidedDecodingParams
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -344,7 +345,7 @@ def compute_hash(self) -> str:
     def __init__(
         self,
         model: str,
-        task: Union[TaskOption, Literal["draft"]],
+        task: Literal[TaskOption, Literal["draft"]],
         tokenizer: str,
         tokenizer_mode: str,
         trust_remote_code: bool,
@@ -701,7 +702,7 @@ def _get_preferred_task(
 
     def _resolve_task(
         self,
-        task_option: Union[TaskOption, Literal["draft"]],
+        task_option: Literal[TaskOption, Literal["draft"]],
     ) -> tuple[set[_ResolvedTask], _ResolvedTask]:
         if task_option == "draft":
             return {"draft"}, "draft"
@@ -3185,13 +3186,36 @@ def get_served_model_name(model: str,
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine."""
 
-    guided_decoding_backend: GuidedDecodingBackend = \
-        "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    @property
+    @deprecated(
+        "`guided_decoding_backend` is deprecated and has been renamed to "
+        "`backend`. This will be removed in v0.10.0. Please use the "
+        "`backend` argument instead.")
+    def guided_decoding_backend(self) -> GuidedDecodingBackend:
+        return self.backend
+
+    @guided_decoding_backend.setter
+    def guided_decoding_backend(self, value: GuidedDecodingBackend):
+        self.backend = value
+
+    backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
     """Which engine will be used for guided decoding (JSON schema / regex etc)
     by default. With "auto", we will make opinionated choices based on request
    contents and what the backend libraries currently support, so the behavior
    is subject to change in each release."""
 
+    disable_fallback: bool = False
+    """If `True`, vLLM will not fallback to a different backend on error."""
+
+    disable_any_whitespace: bool = False
+    """If `True`, the model will not generate any whitespace during guided
+    decoding. This is only supported for xgrammar and guidance backends."""
+
+    disable_additional_properties: bool = False
+    """If `True`, the `guidance` backend will not use `additionalProperties`
+    in the JSON schema. This is only supported for the `guidance` backend and
+    is used to better align its behaviour with `outlines` and `xgrammar`."""
+
     reasoning_backend: Optional[str] = None
     """Select the reasoning parser depending on the model that you're using.
     This is used to parse the reasoning content into OpenAI API format.
@@ -3217,15 +3241,41 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        backend = GuidedDecodingParams(
-            backend=self.guided_decoding_backend).backend_name
+        if ":" in self.backend:
+            self._extract_backend_options()
+
         if envs.VLLM_USE_V1:
             valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
             valid_guided_backends = get_args(GuidedDecodingBackendV0)
-        if backend not in valid_guided_backends:
-            raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
+        if self.backend not in valid_guided_backends:
+            raise ValueError(f"Invalid backend '{self.backend}',"
                             f" must be one of {valid_guided_backends}")
+        if (self.disable_any_whitespace
+                and self.backend not in ("xgrammar", "guidance")):
+            raise ValueError("disable_any_whitespace is only supported for "
+                             "xgrammar and guidance backends.")
+        if (self.disable_additional_properties and self.backend != "guidance"):
+            raise ValueError("disable_additional_properties is only supported "
+                             "for the guidance backend.")
+
+    @deprecated(
+        "Passing guided decoding backend options inside backend in the format "
+        "'backend:...' is deprecated. This will be removed in v0.10.0. Please "
+        "use the dedicated arguments '--disable-fallback', "
+        "'--disable-any-whitespace' and '--disable-additional-properties' "
+        "instead.")
+    def _extract_backend_options(self):
+        """Extract backend options from the backend string."""
+        backend, options = self.backend.split(":")
+        self.backend = cast(GuidedDecodingBackend, backend)
+        options_set = set(options.strip().split(","))
+        if "no-fallback" in options_set:
+            self.disable_fallback = True
+        if "disable-any-whitespace" in options_set:
+            self.disable_any_whitespace = True
+        if "no-additional-properties" in options_set:
+            self.disable_additional_properties = True
 
 
 @dataclass
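The backwards-compatible rename works through a property pair: reads of the old attribute go through a getter marked with typing_extensions.deprecated, and writes are forwarded to the new `backend` field. A standalone sketch of the same pattern (illustrative class and attribute names, not vLLM code):

from dataclasses import dataclass
from typing_extensions import deprecated

@dataclass
class ExampleConfig:
    backend: str = "auto"

    @property
    @deprecated("`old_backend` has been renamed to `backend`.")
    def old_backend(self) -> str:
        # Reading through the old name goes via the deprecated getter.
        return self.backend

    @old_backend.setter
    def old_backend(self, value: str) -> None:
        # Writes to the old name are forwarded to the new field.
        self.backend = value

cfg = ExampleConfig()
cfg.old_backend = "xgrammar"  # forwarded to cfg.backend
assert cfg.backend == "xgrammar"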

vllm/engine/arg_utils.py

Lines changed: 26 additions & 10 deletions
@@ -18,9 +18,9 @@
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
                          DeviceConfig, DistributedExecutorBackend,
-                         GuidedDecodingBackendV1, HfOverrides,
-                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelImpl, MultiModalConfig,
+                         GuidedDecodingBackend, GuidedDecodingBackendV1,
+                         HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
                          PrefixCachingHashAlgo, PromptAdapterConfig,
                          SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
@@ -317,7 +317,12 @@ class EngineArgs:
         bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
 
-    guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
+    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
+    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
+    guided_decoding_disable_any_whitespace: bool = \
+        DecodingConfig.disable_any_whitespace
+    guided_decoding_disable_additional_properties: bool = \
+        DecodingConfig.disable_additional_properties
     logits_processor_pattern: Optional[str] = None
 
     speculative_config: Optional[Dict[str, Any]] = None
@@ -498,9 +503,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            title="DecodingConfig",
            description=DecodingConfig.__doc__,
        )
+        guided_decoding_group.add_argument("--guided-decoding-backend",
+                                           **guided_decoding_kwargs["backend"])
         guided_decoding_group.add_argument(
-            '--guided-decoding-backend',
-            **guided_decoding_kwargs["guided_decoding_backend"])
+            "--guided-decoding-disable-fallback",
+            **guided_decoding_kwargs["disable_fallback"])
+        guided_decoding_group.add_argument(
+            "--guided-decoding-disable-any-whitespace",
+            **guided_decoding_kwargs["disable_any_whitespace"])
+        guided_decoding_group.add_argument(
+            "--guided-decoding-disable-additional-properties",
+            **guided_decoding_kwargs["disable_additional_properties"])
         guided_decoding_group.add_argument(
             "--reasoning-parser",
             # This choices is a special case because it's not static
@@ -1244,7 +1257,11 @@ def create_engine_config(
             if self.enable_prompt_adapter else None
 
         decoding_config = DecodingConfig(
-            guided_decoding_backend=self.guided_decoding_backend,
+            backend=self.guided_decoding_backend,
+            disable_fallback=self.guided_decoding_disable_fallback,
+            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
+            disable_additional_properties=\
+            self.guided_decoding_disable_additional_properties,
             reasoning_backend=self.reasoning_parser
             if self.enable_reasoning else None,
         )
@@ -1335,9 +1352,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=True)
            return False
 
-        # remove backend options when doing this check
-        if self.guided_decoding_backend.split(':')[0] \
-            not in get_args(GuidedDecodingBackendV1):
+        if self.guided_decoding_backend not in get_args(
+                GuidedDecodingBackendV1):
             _raise_or_fallback(
                 feature_name=
                 f"--guided-decoding-backend={self.guided_decoding_backend}",

vllm/engine/llm_engine.py

Lines changed: 1 addition & 1 deletion
@@ -2091,7 +2091,7 @@ def _build_logits_processors(
 
         tokenizer = self.get_tokenizer(lora_request=lora_request)
         guided_decoding.backend = guided_decoding.backend or \
-            self.decoding_config.guided_decoding_backend
+            self.decoding_config.backend
 
         if self.decoding_config.reasoning_backend is not None:
             logger.debug("Building with reasoning backend %s",

vllm/engine/multiprocessing/client.py

Lines changed: 2 additions & 2 deletions
@@ -615,9 +615,9 @@ async def _process_request(
             build_guided_decoding_logits_processor_async(
                 sampling_params=params,
                 tokenizer=await self.get_tokenizer(lora_request),
-                default_guided_backend=(self.decoding_config.guided_decoding_backend
+                default_guided_backend=(self.decoding_config.backend
                                         if self.decoding_config
-                                        else DecodingConfig.guided_decoding_backend),
+                                        else DecodingConfig.backend),
                 model_config=self.model_config,
                 reasoning_backend=self.decoding_config.reasoning_backend,
             )
