
[Frontend] Add backend-specific options for guided decoding #13505

Merged: 7 commits, Feb 20, 2025
2 changes: 1 addition & 1 deletion docs/source/features/structured_outputs.md
@@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:
- `guided_json`: the output will follow the JSON schema.
- `guided_grammar`: the output will follow the context free grammar.
- `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
- `guided_decoding_backend`: used to select the guided decoding backend to use.
- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma-separated list following a colon after the backend name. For example, `"xgrammar:no-fallback"` will not allow vLLM to fall back to a different backend on error.

You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server) page.
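As an illustrative aside (not part of this diff): a minimal request sketch using the new option syntax. The model name, server URL, and API key below are placeholder assumptions for whatever the server was started with.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[{"role": "user", "content": "Return a JSON greeting."}],
    extra_body={
        "guided_json": {
            "type": "object",
            "properties": {"greeting": {"type": "string"}},
            "required": ["greeting"],
        },
        # Backend name, then a colon, then a comma-separated list of options.
        "guided_decoding_backend": "xgrammar:no-fallback",
    },
)
print(completion.choices[0].message.content)
```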

@@ -2,7 +2,7 @@

from enum import Enum

from openai import OpenAI
from openai import BadRequestError, OpenAI
from pydantic import BaseModel

client = OpenAI(
@@ -94,3 +94,26 @@ class CarDescription(BaseModel):
extra_body={"guided_grammar": simplified_sql_grammar},
)
print(completion.choices[0].message.content)

# Extra backend options
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"[email protected]\n")

try:
# The no-fallback option forces vLLM to use xgrammar, so when it fails
# you get a 400 with the reason why
completion = client.chat.completions.create(
model="Qwen/Qwen2.5-3B-Instruct",
messages=[{
"role": "user",
"content": prompt,
}],
extra_body={
"guided_regex": "\w+@\w+\.com\n",
"stop": ["\n"],
"guided_decoding_backend": "xgrammar:no-fallback"
},
)
except BadRequestError as e:
print("This error is expected:", e)
16 changes: 16 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -277,6 +277,22 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_disable_guided_decoding_fallback(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend="xgrammar:no-fallback"))

with pytest.raises(
ValueError,
match="xgrammar does not support regex guided decoding"):
Review comment (Member):

It will soon! :-). #13228

This is totally fine, though. I can tweak this once we're able to turn it on.

llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str):
10 changes: 10 additions & 0 deletions tests/model_executor/test_guided_processors.py
@@ -109,6 +109,16 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")


def test_guided_decoding_backend_options():
"""Test backend-specific options"""
params = GuidedDecodingParams(
backend="xgrammar:option-1,option-2,option-3")
assert params.backend_options() == ["option-1", "option-2", "option-3"]

no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
assert no_fallback.no_fallback()


def test_pickle_xgrammar_tokenizer_data():

# TODO: move to another test file for xgrammar
5 changes: 4 additions & 1 deletion vllm/config.py
@@ -25,6 +25,7 @@
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2603,7 +2604,9 @@ def compute_hash(self) -> str:

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend

backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
f"must be one of {valid_guided_backends}")
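A small editorial sketch (not part of the diff) of what this validation change enables, assuming vLLM is importable: an engine-level backend string carrying options now passes `__post_init__`, because only the name before the colon is checked against the valid backends.

```python
from vllm.config import DecodingConfig

# Options after the colon are stripped before validation, so this is accepted.
ok = DecodingConfig(guided_decoding_backend="xgrammar:no-fallback")

# An unknown backend name still raises, options or not.
try:
    DecodingConfig(guided_decoding_backend="not-a-backend:no-fallback")
except ValueError as e:
    print(e)
```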
7 changes: 5 additions & 2 deletions vllm/engine/arg_utils.py
@@ -371,14 +371,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--guided-decoding-backend',
type=str,
default='xgrammar',
choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines, '
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
' parameter.\n'
        'Backend-specific options can be supplied in a comma-separated '
'list following a colon after the backend name. Valid backends and '
'all available options are: [xgrammar:no-fallback, '
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
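For context (again, an editorial sketch rather than part of the diff), the same `backend:option` syntax works for the engine-wide default; the server equivalent would be roughly `vllm serve <model> --guided-decoding-backend xgrammar:no-fallback`. The model name here is an assumption carried over from the example above.

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Engine-wide default backend, with an option suffix.
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct",
          guided_decoding_backend="xgrammar:no-fallback")

# Individual requests can still override the default, as the help text notes.
params = SamplingParams(
    temperature=0.0,
    guided_decoding=GuidedDecodingParams(choice=["yes", "no"],
                                         backend="outlines"))
print(llm.generate("Is the sky blue?", params)[0].outputs[0].text)
```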
81 changes: 44 additions & 37 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -22,47 +22,56 @@

def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:

def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `no-fallback` option is specified."""
if guided_params.no_fallback():
raise ValueError(message)

logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback

    # lm-format-enforcer doesn't support grammar, fall back to xgrammar
if guided_params.backend == "lm-format-enforcer":
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None:
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
fallback_or_error(
guided_params,
"lm-format-enforcer does not support grammar guided decoding.",
"xgrammar")

# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
logger.warning(
fallback_or_error(
guided_params,
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"features like patterns or numeric ranges.", "outlines")

if guided_params.backend == "xgrammar":
if guided_params.backend_name == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(guided_params,
"xgrammar is only supported on x86 CPUs.",
"outlines")

# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
logger.warning("xgrammar does not support regex guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")

# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
logger.warning(
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"enums, patterns or numeric ranges.", "outlines")

# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
@@ -72,25 +81,23 @@ def maybe_backend_fallback(
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"grammar failed to convert to GBNF.", "outlines")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
logger.warning("xgrammar module cannot be imported successfully. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")

if (guided_params.backend == "outlines"
if (guided_params.backend_name == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
logger.warning("outlines does not support json_object. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
fallback_or_error(guided_params,
"outlines does not support json_object.", "xgrammar")

return guided_params

@@ -100,18 +107,18 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
if guided_params.backend_name == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
@@ -127,18 +134,18 @@ def get_local_guided_decoding_logits_processor(
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
if guided_params.backend_name == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
if guided_params.backend_name == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
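To make the new control flow concrete (an editorial sketch, not part of the diff): `fallback_or_error` either rewrites `guided_params.backend` with a warning or, when `no-fallback` is present, raises the same message as a `ValueError`. This assumes an x86 host where the regex check is what triggers the fallback.

```python
from vllm.sampling_params import GuidedDecodingParams
from vllm.model_executor.guided_decoding import maybe_backend_fallback

# Default behaviour: xgrammar cannot handle regex, so the request is moved
# to outlines with only a warning in the log.
lenient = GuidedDecodingParams(regex=r"\w+", backend="xgrammar")
maybe_backend_fallback(lenient)
print(lenient.backend)  # "outlines"

# With no-fallback, the same situation becomes a hard error, which an API
# caller sees as a 400 response.
strict = GuidedDecodingParams(regex=r"\w+", backend="xgrammar:no-fallback")
try:
    maybe_backend_fallback(strict)
except ValueError as e:
    print(e)  # xgrammar does not support regex guided decoding.
```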
19 changes: 19 additions & 0 deletions vllm/sampling_params.py
@@ -64,6 +64,25 @@ def from_optional(
whitespace_pattern=whitespace_pattern,
)

@property
def backend_name(self) -> str:
"""Return the backend name without any options.

For example if the backend is "xgrammar:no-fallback", returns "xgrammar"
"""
return (self.backend or "").split(":")[0]

def backend_options(self) -> List[str]:
"""Return the backend options as a list of strings."""
if not self.backend or ":" not in self.backend:
return []
return self.backend.split(":")[1].split(",")

def no_fallback(self) -> bool:
"""Returns True if the "no-fallback" option is supplied for the guided
decoding backend"""
return "no-fallback" in self.backend_options()

def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
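One edge case worth noting (editorial, not in the diff): when no backend is supplied at the request level, the new helpers degrade to empty values, so the engine default and fallback logic are unaffected.

```python
from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(json={"type": "object"})  # no backend given
assert params.backend_name == ""       # (None or "").split(":")[0]
assert params.backend_options() == []  # no colon, so no options
assert params.no_fallback() is False   # fallback remains allowed
```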