[Frontend] Add backend-specific options for guided decoding #13505

Merged (7 commits, Feb 20, 2025)
Changes from 1 commit
20 changes: 20 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -3,12 +3,14 @@
 import json
 import re
 import weakref
+from unittest import mock

 import jsonschema
 import pytest

 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
+from vllm.model_executor import guided_decoding as guided_decoding_module
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

@@ -277,6 +279,24 @@
         guided_options_request=dict(guided_regex=sample_regex))


+@pytest.mark.skip_global_cleanup
+def test_disable_guided_decoding_fallback(sample_regex, llm):
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     guided_decoding=GuidedDecodingParams(
+                                         regex=sample_regex,
+                                         backend="xgrammar"))
+
+    with mock.patch.object(guided_decoding_module,
+                           "VLLM_DISABLE_GUIDED_DECODING_FALLBACK", True):
+        with pytest.raises(
+                ValueError,
+                match="xgrammar does not support regex guided decoding"):
+            llm.generate(prompts="This should fail",
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 def test_guided_json_object(llm, guided_decoding_backend: str):

[CI check failure on line 294: GitHub Actions / pre-commit, Ruff (SIM117), tests/entrypoints/llm/test_guided_generate.py:290:5: Use a single `with` statement with multiple contexts instead of nested `with` statements]
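The Ruff SIM117 failure flagged above asks for the two nested context managers to be merged into one `with` statement. A minimal standalone sketch of the transformation, using hypothetical stand-in context managers rather than the real `mock.patch.object`/`pytest.raises` ones:

```python
from contextlib import contextmanager


@contextmanager
def patch_flag():
    # Stand-in for mock.patch.object(...) in the test above.
    yield "patched"


@contextmanager
def expect_error():
    # Stand-in for pytest.raises(...) in the test above.
    yield "expecting"


# Nested form flagged by SIM117:
with patch_flag() as a:
    with expect_error() as b:
        nested = (a, b)

# Equivalent single-statement form Ruff suggests:
with patch_flag() as a, expect_error() as b:
    combined = (a, b)

print(nested == combined)  # True
```

Both forms enter and exit the context managers in the same order; the single-statement form just removes one indentation level.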
3 changes: 3 additions & 0 deletions vllm/envs.py
@@ -89,6 +89,7 @@
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
+    VLLM_DISABLE_GUIDED_DECODING_FALLBACK: bool = False


 def get_default_cache_root():
@@ -585,6 +586,8 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":
     lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
+    "VLLM_DISABLE_GUIDED_DECODING_FALLBACK":
+    lambda: bool(int(os.getenv("VLLM_DISABLE_GUIDED_DECODING_FALLBACK", "0"))),
Comment (Member):

Have we an established pattern on what should be config vs env variables? Why wouldn't this be in `DecodingConfig`? Maybe we could encode "don't fallback" in something like `--guided-decoding-backend=outlines:nofallback` if we were worried about a proliferation of CLI arguments.

Comment (Member):

I'm okay either way, but I do think Mark's suggestion would be nicer. I like calling more attention to the `--guided-decoding-backend` argument if users want to be explicit about their backend.

Comment (Member):

I like the CLI arg approach ... before seeing this comment I was thinking about another "backend" like `xgrammar-only` or something like that. `xgrammar:nofallback` leaves it open to a bit more flexibility to specify additional options if necessary later, like `xgrammar:nofallback,json-any-whitespace` to support the case covered in #12744.
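A rough sketch of how such a `backend:opt1,opt2` spec could be parsed (a hypothetical helper for illustration, not the syntax this PR ultimately implements):

```python
def parse_backend_spec(spec: str) -> tuple[str, list[str]]:
    # Split e.g. "xgrammar:nofallback,json-any-whitespace" into the backend
    # name and its comma-separated option flags.
    backend, _, raw_opts = spec.partition(":")
    options = [opt for opt in raw_opts.split(",") if opt]
    return backend, options


backend, options = parse_backend_spec("xgrammar:nofallback,json-any-whitespace")
print(backend)   # xgrammar
print(options)   # ['nofallback', 'json-any-whitespace']
print(parse_backend_spec("outlines"))  # ('outlines', [])
```

A plain backend name with no colon parses to an empty option list, so existing `--guided-decoding-backend=outlines` style values would keep working unchanged.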

Comment (Member):

> Have we an established pattern on what should be config vs env variables?

Yeah ... I keep thinking about this. It's going to be a big project, but we're due for significant cleanup here. I'd really like a system that supports both config files and command-line args (and fewer env vars, unless an env var is just an alternative way of setting the same options).

... but I have no idea when that's going to feel like the most important thing to work on!

Comment (Contributor):

> Have we an established pattern on what should be config vs env variables?
>
> Yeah ... I keep thinking about this.

Me too. I like to think that environment variables change the 'behavior' of the system, like enabling a deprecated | experimental | workaround feature, while config options are the 'common features' of the system that are up to users to set or tune for their environment.

However, there is also something else relevant to this discussion. When we set options in the config, we have a chance to log the system setup, like the log `Initializing a V0 LLM engine (v%s) with config: [...]`. Sometimes it is tricky to get the exact setup of the system when we get a crash and the only thing we have is a stack trace (which may be truncated as well 😄).

We should probably prefer args over envs, but when it makes sense to use an env var, we could log to users (at least once) that vLLM has this feature on and its implications for the system.

Either way: I like the idea of `--guided-decoding-backend=outlines:nofallback` for this PR. And I'm pretty sure it would be logged during system initialization, which is nice for debugging purposes.

Comment (Collaborator, Author):

Thanks for the discussion everybody 👍

I stuck this in the environment to avoid a proliferation of CLI args, but I love the suggestion of encoding the fallback behavior in the name of the backend. Best of both worlds!

I'll update the implementation

Comment (Collaborator, Author):

Code is updated if y'all wanna take a second look 🙏

 }

 # end-env-vars-definition
64 changes: 36 additions & 28 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -4,6 +4,7 @@

 from typing import TYPE_CHECKING

+from vllm.envs import VLLM_DISABLE_GUIDED_DECODING_FALLBACK
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
@@ -22,47 +23,56 @@

 def maybe_backend_fallback(
         guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
+
+    def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
+                          fallback: str) -> None:
+        """Change the backend to the specified fallback with a warning log,
+        or raise a ValueError if guided decoding fallback is not enabled."""
+        if VLLM_DISABLE_GUIDED_DECODING_FALLBACK:
+            raise ValueError(message)
+
+        logger.warning("%s Falling back to use %s instead.", message, fallback)
+        guided_params.backend = fallback
+
     # lm-format-enforce doesn't support grammar, fallback to xgrammar
     if guided_params.backend == "lm-format-enforcer":
         if guided_params.grammar is not None:
-            logger.warning(
-                "lm-format-enforcer does not support grammar guided decoding. "
-                "Falling back to use xgrammar instead.")
-            guided_params.backend = "xgrammar"
+            fallback_or_error(
+                guided_params,
+                "lm-format-enforcer does not support grammar guided decoding.",
+                "xgrammar")

         # lm-format-enforcer doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_lmf_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "lm-format-enforcer does not support advanced JSON schema "
-                "features like patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "features like patterns or numeric ranges.", "outlines")

     if guided_params.backend == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)
         # xgrammar only has x86 wheels for linux, fallback to outlines
         from vllm.platforms import current_platform
         if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
-            logger.warning("xgrammar is only supported on x86 CPUs. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(guided_params,
+                              "xgrammar is only supported on x86 CPUs.",
+                              "outlines")

         # xgrammar doesn't support regex, fallback to outlines
         if guided_params.regex is not None:
-            logger.warning("xgrammar does not support regex guided decoding. "
-                           "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+            fallback_or_error(
+                guided_params,
+                "xgrammar does not support regex guided decoding.", "outlines")

         # xgrammar doesn't support some JSON schema features
         elif (guided_params.json is not None
               and has_xgrammar_unsupported_json_features(guided_params.json)):
-            logger.warning(
+            fallback_or_error(
+                guided_params,
                 "xgrammar does not support advanced JSON schema features like "
-                "patterns or numeric ranges. "
-                "Falling back to use outlines instead.")
-            guided_params.backend = "outlines"
+                "enums, patterns or numeric ranges.", "outlines")

     # xgrammar only supports GBNF grammars, so we must convert Lark.
     # We must check if the grammar is likely Lark and if that
Expand All @@ -72,25 +82,23 @@ def maybe_backend_fallback(
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
logger.warning(
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
"grammar failed to convert to GBNF.", "outlines")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
logger.warning("xgrammar module cannot be imported successfully. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")

if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to xgrammar
logger.warning("outlines does not support json_object. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
fallback_or_error(guided_params,
"outlines does not support json_object.", "xgrammar")

return guided_params
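Standalone, the fallback-or-error behavior this diff introduces amounts to the following sketch (simplified names as stand-ins, not the actual vLLM classes or env plumbing):

```python
import logging

logger = logging.getLogger("guided_decoding_sketch")

# Stand-in for the VLLM_DISABLE_GUIDED_DECODING_FALLBACK env flag.
DISABLE_FALLBACK = False


class Params:
    """Minimal stand-in for GuidedDecodingParams."""

    def __init__(self, backend: str):
        self.backend = backend


def fallback_or_error(params: Params, message: str, fallback: str) -> None:
    # When fallback is disabled, surface the incompatibility as an error;
    # otherwise warn and silently switch to the fallback backend.
    if DISABLE_FALLBACK:
        raise ValueError(message)
    logger.warning("%s Falling back to use %s instead.", message, fallback)
    params.backend = fallback


p = Params("xgrammar")
fallback_or_error(p, "xgrammar does not support regex guided decoding.",
                  "outlines")
print(p.backend)  # outlines
```

With the flag set, the same call would raise `ValueError` with the message, which is exactly what the new test above asserts via `pytest.raises`.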
