
Commit 5a84d61

joerunde authored and shreyankg committed
[Frontend] Add backend-specific options for guided decoding (vllm-project#13505)
Signed-off-by: Joe Runde <[email protected]>
1 parent 500b058 commit 5a84d61

File tree: 8 files changed, +123 −42 lines changed


docs/source/features/structured_outputs.md
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:
 - `guided_json`: the output will follow the JSON schema.
 - `guided_grammar`: the output will follow the context free grammar.
 - `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
-- `guided_decoding_backend`: used to select the guided decoding backend to use.
+- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma-separated list following a colon after the backend name. For example, `"xgrammar:no-fallback"` will not allow vLLM to fall back to a different backend on error.

 You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page.

examples/online_serving/openai_chat_completion_structured_outputs.py
Lines changed: 24 additions & 1 deletion

@@ -2,7 +2,7 @@
 from enum import Enum

-from openai import OpenAI
+from openai import BadRequestError, OpenAI
 from pydantic import BaseModel

 client = OpenAI(
@@ -94,3 +94,26 @@ class CarDescription(BaseModel):
     extra_body={"guided_grammar": simplified_sql_grammar},
 )
 print(completion.choices[0].message.content)
+
+# Extra backend options
+prompt = ("Generate an email address for Alan Turing, who works in Enigma."
+          "End in .com and new line. Example result:"
+          "[email protected]\n")
+
+try:
+    # The no-fallback option forces vLLM to use xgrammar, so when it fails
+    # you get a 400 with the reason why
+    completion = client.chat.completions.create(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        messages=[{
+            "role": "user",
+            "content": prompt,
+        }],
+        extra_body={
+            "guided_regex": "\w+@\w+\.com\n",
+            "stop": ["\n"],
+            "guided_decoding_backend": "xgrammar:no-fallback"
+        },
+    )
+except BadRequestError as e:
+    print("This error is expected:", e)

tests/entrypoints/llm/test_guided_generate.py
Lines changed: 16 additions & 0 deletions

@@ -280,6 +280,22 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
         guided_options_request=dict(guided_regex=sample_regex))


+@pytest.mark.skip_global_cleanup
+def test_disable_guided_decoding_fallback(sample_regex, llm):
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     guided_decoding=GuidedDecodingParams(
+                                         regex=sample_regex,
+                                         backend="xgrammar:no-fallback"))
+
+    with pytest.raises(
+            ValueError,
+            match="xgrammar does not support regex guided decoding"):
+        llm.generate(prompts="This should fail",
+                     sampling_params=sampling_params,
+                     use_tqdm=True)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 def test_guided_json_object(llm, guided_decoding_backend: str):

tests/model_executor/test_guided_processors.py
Lines changed: 10 additions & 0 deletions

@@ -109,6 +109,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
         GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


+def test_guided_decoding_backend_options():
+    """Test backend-specific options"""
+    params = GuidedDecodingParams(
+        backend="xgrammar:option-1,option-2,option-3")
+    assert params.backend_options() == ["option-1", "option-2", "option-3"]
+
+    no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
+    assert no_fallback.no_fallback()
+
+
 def test_pickle_xgrammar_tokenizer_data():

     # TODO: move to another test file for xgrammar

vllm/config.py
Lines changed: 4 additions & 1 deletion

@@ -25,6 +25,7 @@
     get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import CpuArchEnum
+from vllm.sampling_params import GuidedDecodingParams
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2633,7 +2634,9 @@ def compute_hash(self) -> str:

 def __post_init__(self):
     valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
-    backend = self.guided_decoding_backend
+
+    backend = GuidedDecodingParams(
+        backend=self.guided_decoding_backend).backend_name
     if backend not in valid_guided_backends:
         raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                          f"must be one of {valid_guided_backends}")

vllm/engine/arg_utils.py
Lines changed: 5 additions & 2 deletions

@@ -373,14 +373,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         '--guided-decoding-backend',
         type=str,
         default='xgrammar',
-        choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
         help='Which engine will be used for guided decoding'
         ' (JSON schema / regex etc) by default. Currently support '
         'https://github.com/outlines-dev/outlines, '
         'https://github.com/mlc-ai/xgrammar, and '
         'https://github.com/noamgat/lm-format-enforcer.'
         ' Can be overridden per request via guided_decoding_backend'
-        ' parameter.')
+        ' parameter.\n'
+        'Backend-specific options can be supplied in a comma-separated '
+        'list following a colon after the backend name. Valid backends and '
+        'all available options are: [xgrammar:no-fallback, '
+        'outlines:no-fallback, lm-format-enforcer:no-fallback]')
     parser.add_argument(
         '--logits-processor-pattern',
         type=nullable_str,
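
Note that the fixed `choices` list had to be dropped so that option suffixes are accepted; validation now happens in `DecodingConfig.__post_init__` above. The same extended string can also be set programmatically. A minimal sketch, assuming the offline `LLM` entrypoint forwards engine arguments (the model name is illustrative):

from vllm import LLM

# Equivalent to passing --guided-decoding-backend "xgrammar:no-fallback"
# on the vllm serve command line.
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct",
          guided_decoding_backend="xgrammar:no-fallback")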

vllm/model_executor/guided_decoding/__init__.py
Lines changed: 44 additions & 37 deletions

@@ -22,47 +22,56 @@

 def maybe_backend_fallback(
         guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
+
+    def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
+                          fallback: str) -> None:
+        """Change the backend to the specified fallback with a warning log,
+        or raise a ValueError if the `no-fallback` option is specified."""
+        if guided_params.no_fallback():
+            raise ValueError(message)
+
+        logger.warning("%s Falling back to use %s instead.", message, fallback)
+        guided_params.backend = fallback
+
     # lm-format-enforcer doesn't support grammar, fall back to xgrammar
-    if guided_params.backend == "lm-format-enforcer":
+    if guided_params.backend_name == "lm-format-enforcer":
         if guided_params.grammar is not None:
-            logger.warning(
-                "lm-format-enforcer does not support grammar guided decoding. "
-                "Falling back to use xgrammar instead.")
-            guided_params.backend = "xgrammar"
+            fallback_or_error(
+                guided_params,
+                "lm-format-enforcer does not support grammar guided decoding.",
+                "xgrammar")

         # lm-format-enforcer doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_lmf_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "lm-format-enforcer does not support advanced JSON schema "
-                "features like patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "features like patterns or numeric ranges.", "outlines")

-    if guided_params.backend == "xgrammar":
+    if guided_params.backend_name == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
         # xgrammar only has x86 wheels for linux, fall back to outlines
         from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            logger.warning("xgrammar is only supported on x86 CPUs. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(guided_params,
+                              "xgrammar is only supported on x86 CPUs.",
+                              "outlines")

         # xgrammar doesn't support regex, fall back to outlines
         if guided_params.regex is not None:
-            logger.warning("xgrammar does not support regex guided decoding. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(
+                guided_params,
+                "xgrammar does not support regex guided decoding.", "outlines")

         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "xgrammar does not support advanced JSON schema features like "
-                "patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "enums, patterns or numeric ranges.", "outlines")

         # xgrammar only supports GBNF grammars, so we must convert Lark.
         # We must check if the grammar is likely Lark and if that
@@ -72,25 +81,23 @@ def maybe_backend_fallback(
             try:
                 convert_lark_to_gbnf(guided_params.grammar)
             except Exception:
-                logger.warning(
+                fallback_or_error(
+                    guided_params,
                     "xgrammar does not support Lark grammars and the "
-                    "grammar failed to convert to GBNF. "
-                    "Falling back to use outlines instead.")
-                guided_params.backend = "outlines"
+                    "grammar failed to convert to GBNF.", "outlines")

         # If the xgrammar module cannot be imported successfully,
         # we should still allow users to use guided decoding with a fallback.
         elif not xgr_installed:
-            logger.warning("xgrammar module cannot be imported successfully. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(
+                guided_params,
+                "xgrammar module cannot be imported successfully.", "outlines")

-    if (guided_params.backend == "outlines"
+    if (guided_params.backend_name == "outlines"
             and guided_params.json_object is not None):
         # outlines doesn't support json_object, fall back to xgrammar
-        logger.warning("outlines does not support json_object. "
-                       "Falling back to use xgrammar instead.")
-        guided_params.backend = "xgrammar"
+        fallback_or_error(guided_params,
+                          "outlines does not support json_object.", "xgrammar")

     return guided_params

@@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor(
         model_config: ModelConfig) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines':
+    if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'lm-format-enforcer':
+    if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'xgrammar':
+    if guided_params.backend_name == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
@@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor(
         model_config: ModelConfig) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines':
+    if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'lm-format-enforcer':
+    if guided_params.backend_name == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend == 'xgrammar':
+    if guided_params.backend_name == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
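
Taken together, the refactor routes every fallback decision through the single `fallback_or_error` helper. A short sketch of the two resulting behaviors (assuming xgrammar is selected and a regex guide is requested, which xgrammar does not support):

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# Default: logs a warning and switches the backend to "outlines".
relaxed = maybe_backend_fallback(
    GuidedDecodingParams(regex=r"\w+", backend="xgrammar"))
print(relaxed.backend)  # "outlines"

# With no-fallback: the same condition raises a ValueError instead,
# which the OpenAI-compatible frontend surfaces as a 400 error.
try:
    maybe_backend_fallback(
        GuidedDecodingParams(regex=r"\w+", backend="xgrammar:no-fallback"))
except ValueError as e:
    print("Expected:", e)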

vllm/sampling_params.py
Lines changed: 19 additions & 0 deletions

@@ -64,6 +64,25 @@ def from_optional(
             whitespace_pattern=whitespace_pattern,
         )

+    @property
+    def backend_name(self) -> str:
+        """Return the backend name without any options.
+
+        For example if the backend is "xgrammar:no-fallback", returns "xgrammar"
+        """
+        return (self.backend or "").split(":")[0]
+
+    def backend_options(self) -> List[str]:
+        """Return the backend options as a list of strings."""
+        if not self.backend or ":" not in self.backend:
+            return []
+        return self.backend.split(":")[1].split(",")
+
+    def no_fallback(self) -> bool:
+        """Returns True if the "no-fallback" option is supplied for the guided
+        decoding backend"""
+        return "no-fallback" in self.backend_options()
+
     def __post_init__(self):
         """Validate that some fields are mutually exclusive."""
         guide_count = sum([
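
The parsing contract these helpers establish, in brief (a minimal illustration mirroring the unit test above; "option-1" is a placeholder, not a real option):

from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(backend="xgrammar:no-fallback,option-1")
print(params.backend_name)       # "xgrammar"
print(params.backend_options())  # ["no-fallback", "option-1"]
print(params.no_fallback())      # True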
