Skip to content

Simplify (and fix) passing of guided decoding backend options #17008

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
360026f
Split `guided_decoding_backend` into `guided_decoding_backend` and `g…
hmellor Apr 22, 2025
0390b6f
Typo
hmellor Apr 22, 2025
dc9f57b
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 24, 2025
49b89e5
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 25, 2025
8a1ed06
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 28, 2025
7a660b0
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 28, 2025
592ff3d
Fix typo
hmellor Apr 28, 2025
2979c31
Use bool flags instead of cramming many flags into a dict
hmellor Apr 28, 2025
8a367ba
Fix word order in args
hmellor Apr 28, 2025
b9d4074
Fix missing arg
hmellor Apr 28, 2025
0884458
Enforce disable additional properties only works for guidance
hmellor Apr 28, 2025
4ebe137
Missed the EngineArgs field
hmellor Apr 28, 2025
05ab20b
Add backward compatible deprecated `guided_decoding_backend`
hmellor Apr 28, 2025
a34a434
Fix incorrect attribute in xgrammar
hmellor Apr 28, 2025
ae2f1c3
Fix test parameters
hmellor Apr 29, 2025
666ba39
Merge `Literal`s with `Literal` not `Union`
hmellor Apr 29, 2025
6df5dbb
Enforce that `Literal`s are merged with `Literal` not `Union`
hmellor Apr 29, 2025
08d5e20
Add tests for `config` decorator
hmellor Apr 29, 2025
92300b6
Create new helper function to handle sequences of literals
hmellor Apr 29, 2025
88c1479
Add test for literal to kwarg
hmellor Apr 29, 2025
4037013
Add test cases for `list[Literal]` and `Literal[Literal, Literal]`
hmellor Apr 29, 2025
0ecf76e
Fix pre-commit
hmellor Apr 29, 2025
8beb8df
Respond to comment
hmellor Apr 29, 2025
e44135c
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 29, 2025
008b037
Merge branch 'config-literal-handling' into split-guided-decoding-bac…
hmellor Apr 29, 2025
39557d7
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
"[email protected]\n")

try:
# The no-fallback option forces vLLM to use xgrammar, so when it fails
# you get a 400 with the reason why
# The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create(
model=model,
messages=[{
Expand All @@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
extra_body={
"guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"],
"guided_decoding_backend": "xgrammar:no-fallback"
"guided_decoding_backend": "xgrammar",
"guided_decoding_disable_fallback": True,
},
)
return completion.choices[0].message.content
Expand Down
193 changes: 117 additions & 76 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
"outlines",
"lm-format-enforcer",
"xgrammar:disable-any-whitespace",
"guidance:disable-any-whitespace",
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]


Expand All @@ -36,13 +37,17 @@ def llm():


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
Expand All @@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
Expand All @@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm,


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an assignment grade "
f"that fits this schema: {sample_complex_json_schema}"
Expand All @@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for solving 8x + 7 = -23 "
f"that fits this schema: {sample_definition_json_schema}"
Expand All @@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
"Create a bug report JSON that fits this schema: "
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
Expand Down Expand Up @@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
Expand All @@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm,


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
Expand Down Expand Up @@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
top_p=0.95,
guided_decoding=GuidedDecodingParams(
json=unsupported_json,
backend="xgrammar:no-fallback"))
backend="xgrammar",
disable_fallback=True))

with pytest.raises(
ValueError,
Expand All @@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))

outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
Expand All @@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
print(generated_text)
assert generated_text is not None

if 'disable-any-whitespace' in guided_decoding_backend:
if disable_any_whitespace:
assert "\n" not in generated_text

# Parse to verify it is valid JSON
Expand All @@ -359,14 +393,18 @@ class CarDescription(BaseModel):


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
Expand Down Expand Up @@ -466,8 +504,12 @@ def test_guidance_no_additional_properties(llm):
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"<|im_end|>\n<|im_start|>assistant\n")

def generate_with_backend(backend):
guided_params = GuidedDecodingParams(json=schema, backend=backend)
def generate_with_backend(backend, disable_additional_properties):
guided_params = GuidedDecodingParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=disable_additional_properties)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
Expand All @@ -481,7 +523,7 @@ def generate_with_backend(backend):
jsonschema.validate(instance=parsed_json, schema=schema)
return parsed_json

base_generated = generate_with_backend('guidance:disable-any-whitespace')
base_generated = generate_with_backend("guidance", False)
assert "a1" in base_generated
assert "a2" in base_generated
assert "a3" in base_generated
Expand All @@ -490,8 +532,7 @@ def generate_with_backend(backend):
assert "a5" in base_generated
assert "a6" in base_generated

generated = generate_with_backend(
'guidance:no-additional-properties,disable-any-whitespace')
generated = generate_with_backend("guidance", True)
assert "a1" in generated
assert "a2" in generated
assert "a3" in generated
Expand Down
15 changes: 9 additions & 6 deletions tests/model_executor/test_guided_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):

def test_guided_decoding_backend_options():
"""Test backend-specific options"""
params = GuidedDecodingParams(
backend="xgrammar:option-1,option-2,option-3")
assert params.backend_options() == ["option-1", "option-2", "option-3"]

no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
assert no_fallback.no_fallback()
with pytest.warns(DeprecationWarning):
guided_decoding_params = GuidedDecodingParams(
backend=
"xgrammar:no-fallback,disable-any-whitespace,no-additional-properties"
)
assert guided_decoding_params.backend == "xgrammar"
assert guided_decoding_params.disable_fallback
assert guided_decoding_params.disable_any_whitespace
assert guided_decoding_params.disable_additional_properties


def test_pickle_xgrammar_tokenizer_data():
Expand Down
Loading