
[Frontend] Support guidance:no-additional-properties for compatibility with xgrammar #15949


Merged: 4 commits, Apr 23, 2025
Changes from all commits
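Before the per-file diffs, a minimal usage sketch of the option this PR adds. The model name, schema, and prompt are taken from the tests below and are illustrative, not canonical:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"a1": {"type": "string"}},
    "required": ["a1"],
}

# 'no-additional-properties' tells the guidance backend to reject JSON keys
# that the schema does not declare, matching xgrammar's default behavior.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(
        json=schema,
        backend="guidance:no-additional-properties,disable-any-whitespace"),
)
outputs = llm.generate(prompts="Generate a JSON object.",
                       sampling_params=sampling_params)
print(outputs[0].outputs[0].text)  # e.g. {"a1":"b1"}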
60 changes: 59 additions & 1 deletion tests/entrypoints/llm/test_guided_generate.py
@@ -383,4 +383,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)


@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(llm):
    schema = {
        'type': 'object',
        'properties': {
            'a1': {
                'type': 'string'
            },
            'a2': {
                'type': 'string'
            },
            'a3': {
                'type': 'string'
            }
        },
        'required': ['a1', 'a2', 'a3'],
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        guided_params = GuidedDecodingParams(json=schema, backend=backend)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         guided_decoding=guided_params)

        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        assert outputs is not None
        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    base_generated = generate_with_backend('guidance:disable-any-whitespace')
    assert "a1" in base_generated
    assert "a2" in base_generated
    assert "a3" in base_generated
    # by default additional keys are generated
    assert "a4" in base_generated
    assert "a5" in base_generated
    assert "a6" in base_generated

    generated = generate_with_backend(
        'guidance:no-additional-properties,disable-any-whitespace')
    assert "a1" in generated
    assert "a2" in generated
    assert "a3" in generated
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated
56 changes: 56 additions & 0 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -412,3 +412,59 @@ def test_structured_output_auto_mode(
        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)


@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    backend = 'guidance:no-additional-properties,disable-any-whitespace'
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend=backend)

    schema = {
        'type': 'object',
        'properties': {
            'a1': {
                'type': 'string'
            },
            'a2': {
                'type': 'string'
            },
            'a3': {
                'type': 'string'
            }
        },
        'required': ['a1', 'a2', 'a3'],
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        guided_params = GuidedDecodingParams(json=schema, backend=backend)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         guided_decoding=guided_params)

        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        assert outputs is not None
        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    generated = generate_with_backend(
        'guidance:no-additional-properties,disable-any-whitespace')
    assert "a1" in generated
    assert "a2" in generated
    assert "a3" in generated
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -3091,7 +3091,7 @@ def get_served_model_name(model: str,


 GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
-                                  "xgrammar"]
+                                  "xgrammar", "guidance"]
 GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]


18 changes: 9 additions & 9 deletions vllm/engine/arg_utils.py
@@ -18,7 +18,8 @@
 from vllm import version
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
-                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
+                         DeviceConfig, DistributedExecutorBackend,
+                         GuidedDecodingBackendV1, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
@@ -1370,14 +1371,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=True)
             return False
 
-        # Xgrammar and Guidance are supported.
-        SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
-            _raise_or_fallback(feature_name="--guided-decoding-backend",
-                               recommend_to_remove=False)
+        # remove backend options when doing this check
+        if self.guided_decoding_backend.split(':')[0] \
+                not in get_args(GuidedDecodingBackendV1):
+            _raise_or_fallback(
+                feature_name=
+                f"--guided-decoding-backend={self.guided_decoding_backend}",
+                recommend_to_remove=False)
             return False
 
         # Need at least Ampere for now (FA support required).
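For clarity, a standalone sketch of what this check now accepts. The literal values come from vllm/config.py above, and the split mirrors the code; this is an illustration, not the vLLM implementation:

from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

def is_v1_supported(guided_decoding_backend: str) -> bool:
    # Strip any ':option,option' suffix before checking the backend name,
    # so arbitrary option combinations no longer need to be enumerated.
    return guided_decoding_backend.split(':')[0] in get_args(
        GuidedDecodingBackendV1)

assert is_v1_supported("guidance:no-additional-properties,disable-any-whitespace")
assert is_v1_supported("xgrammar")
assert not is_v1_supported("lm-format-enforcer")  # V0-only backend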
15 changes: 13 additions & 2 deletions vllm/model_executor/guided_decoding/guidance_decoding.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 from re import escape as regex_escape
 
 import llguidance
@@ -7,6 +8,8 @@
 from vllm.model_executor.guided_decoding.guidance_logits_processors import (
     GuidanceLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
+from vllm.v1.structured_output.backend_guidance import (
+    process_for_additional_properties)


def get_local_guidance_guided_decoding_logits_processor(
@@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor(
grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
if guided_params.json:
if (guide_json := guided_params.json) is not None:
# Optionally set additionalProperties to False at the top-level
# By default, other backends do not allow additional top-level
# properties, so this makes guidance more similar to other backends
if 'no-additional-properties' in guided_params.backend_options():
if not isinstance(guide_json, str):
guide_json = json.dumps(guide_json)
guide_json = process_for_additional_properties(guide_json)

grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json,
guide_json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
Expand Down
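To make the schema transformation concrete, here is a self-contained sketch of the walk defined in backend_guidance.py below; the logic is copied for illustration, with simplified names:

import copy

def _walk(data):
    # Recurse first, then mark any object schema that declares properties
    # but leaves additionalProperties unset.
    if isinstance(data, dict):
        for value in data.values():
            _walk(value)
        if 'additionalProperties' not in data and \
                ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            _walk(item)

schema = {
    'type': 'object',
    'properties': {
        'a1': {'type': 'string'},
        'nested': {
            'type': 'object',
            'properties': {'b1': {'type': 'string'}},
        },
    },
}
processed = copy.deepcopy(schema)
_walk(processed)
# Every object schema that declares 'properties' (or 'patternProperties')
# and does not set 'additionalProperties' now forbids extra keys:
assert processed['additionalProperties'] is False
assert processed['properties']['nested']['additionalProperties'] is False
# Schemas without 'properties' (the string leaves) are left untouched:
assert 'additionalProperties' not in processed['properties']['a1']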
8 changes: 0 additions & 8 deletions vllm/v1/engine/processor.py
@@ -145,15 +145,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-
         engine_level_backend = self.decoding_config.guided_decoding_backend
-        if engine_level_backend not in supported_backends:
-            raise ValueError(f"Only {supported_backends} structured output is "
-                             "supported in V1.")
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
             # The values may differ if `params` is reused and was set
60 changes: 51 additions & 9 deletions vllm/v1/structured_output/backend_guidance.py
@@ -1,14 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -29,6 +31,29 @@
logger = init_logger(__name__)


def _walk_json_for_additional_properties(data: object):
    if isinstance(data, dict):
        for value in data.values():
            _walk_json_for_additional_properties(value)
        if 'additionalProperties' not in data and \
                ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            _walk_json_for_additional_properties(item)
Comment on lines +34 to +43

Reviewer (Member):
We might want to limit recursion depth to some reasonable amount here since it's processing untrusted user input. It could be abused to crash a vllm instance by causing excessive recursion depth. Google says the default recursion depth limit is 1000-ish, so that would be pretty trivial to exploit.

Or we can catch RecursionError and handle it cleanly, I suppose ... but a reasonable hard limit seems safer.

tjohnson31415 (Contributor, Author):
Yeah, good point. I wonder where else a deeply nested JSON guide could break vllm. Guarding against / limiting the depth of the guide JSON seems like something that should be handled consistently across backends/options higher in the stack.

tjohnson31415 (Contributor, Author), Apr 18, 2025:
I generated a deeply nested JSON schema and ran some tests. vLLM V1 actually handles the RecursionError and doesn't crash; the result is just a 400 or 500 back to the request. There are a few ways it manifests:

  • When using the chat API with tools, a RecursionError is raised during chat templating when tojson is called in the template
  • With the "auto" backend in V1, the error is raised from has_xgrammar_unsupported_json_features()
  • With the guidance backend in V1, a ValueError is raised from validate_guidance_grammar(), which calls serialize_guidance_grammar() without setting any options, so it doesn't hit the new code

However, the V0 engine does crash when using v1/completions. In addition to the new recursion with guidance:no-additional-properties, there are RecursionErrors raised in calls to has_*_unsupported_json_features() that can crash it.

So I don't think adding recursion here makes anything worse. IMO, there can be a follow-up PR to limit the depth of the JSON schema in one place.
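A minimal sketch of the depth-limited variant discussed above; the max_depth cap is hypothetical and not part of this PR:

# Hypothetical depth-limited walk: fail fast with a clear error instead of
# risking an unhandled RecursionError on adversarially nested schemas.
def _walk_json_for_additional_properties_limited(data: object, depth: int = 0,
                                                 max_depth: int = 100):
    if depth > max_depth:
        raise ValueError(
            f"JSON schema exceeds maximum nesting depth of {max_depth}")
    if isinstance(data, dict):
        for value in data.values():
            _walk_json_for_additional_properties_limited(
                value, depth + 1, max_depth)
        if 'additionalProperties' not in data and \
                ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            _walk_json_for_additional_properties_limited(
                item, depth + 1, max_depth)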



def process_for_additional_properties(
        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    if isinstance(guide_json, str):
        guide_json_obj = json.loads(guide_json)
    else:
        # copy for modifications
        guide_json_obj = copy.deepcopy(guide_json)
    _walk_json_for_additional_properties(guide_json_obj)
    return guide_json_obj


class GuidanceBackend(StructuredOutputBackend):

    def __init__(self, vllm_config: VllmConfig):
@@ -41,9 +66,20 @@ def __init__(self, vllm_config: VllmConfig):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
+
+        self.disable_any_whitespace = False
+        self.no_additional_properties = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            elif option == "no-additional-properties":
+                self.no_additional_properties = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the guidance backend: {option}")
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
@@ -52,7 +88,8 @@
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec, self.disable_any_whitespace)
+            request_type, grammar_spec, self.disable_any_whitespace,
+            self.no_additional_properties)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -129,10 +166,15 @@ def reset(self):
         self.ll_matcher.reset()
 
 
-def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str,
-                               disable_any_whitespace: bool = False) -> str:
+def serialize_guidance_grammar(
+    request_type: StructuredOutputOptions,
+    grammar_spec: Union[str, dict[str, Any]],
+    disable_any_whitespace: bool = False,
+    no_additional_properties: bool = False,
+) -> str:
     if request_type == StructuredOutputOptions.JSON:
+        if no_additional_properties:
+            grammar_spec = process_for_additional_properties(grammar_spec)
         return llguidance.LLMatcher.grammar_from_json_schema(
             grammar_spec,
             defaults={
16 changes: 12 additions & 4 deletions vllm/v1/structured_output/backend_xgrammar.py
@@ -9,7 +9,7 @@
 import vllm.envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
@@ -32,16 +32,24 @@ class XgrammarBackend(StructuredOutputBackend):

def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.disable_any_whitespace = (
"disable-any-whitespace"
in vllm_config.decoding_config.guided_decoding_backend)
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()

self.disable_any_whitespace = False
backend_options = GuidedDecodingParams(
backend=vllm_config.decoding_config.guided_decoding_backend
).backend_options()
for option in backend_options:
if option == "disable-any-whitespace":
self.disable_any_whitespace = True
else:
raise ValueError(
f"Unsupported option for the xgrammar backend: {option}")

tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
Expand Down
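As a usage note, both backends now validate their option flags at engine initialization. A sketch, assuming backend_options() returns the comma-separated flags after ':' as the code above relies on:

from vllm.sampling_params import GuidedDecodingParams

# backend_options() yields the flags after ':' (per its use in the diffs above).
opts = GuidedDecodingParams(
    backend="guidance:no-additional-properties,disable-any-whitespace"
).backend_options()
assert opts == ["no-additional-properties", "disable-any-whitespace"]

# Passing an option a backend does not understand now fails fast instead of
# being silently ignored; e.g. XgrammarBackend raises:
#   ValueError: Unsupported option for the xgrammar backend: no-additional-properties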